Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Chore: refine shutdown signal of PyExecutor (#4614)
* refine shutdown signal of PyExecutor
* clean
* fix ci
* fix ci

Signed-off-by: QI JUN <22017000+QiJune@users.noreply.github.com>
parent 2fee408536
commit 4a81991b65
@@ -48,13 +48,7 @@ PROFILE_RECORD_GC_ENV_VAR_NAME = "TLLM_PROFILE_RECORD_GC"
 # Set to a path to save detailed tracing of PyTorch operations.
 PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE"
 
-
-def _is_executor_request(req_queue_item) -> bool:
-    return isinstance(req_queue_item, tuple)
-
-
-def _is_cancel_request(req_queue_item) -> bool:
-    return isinstance(req_queue_item, int)
+SHUTDOWN_REQUEST_ID = -1
 
 
 def _get_from_request_queue(request_queue, timeout: datetime.timedelta,
@@ -71,8 +65,8 @@ def _get_from_request_queue(request_queue, timeout: datetime.timedelta,
         while req_count < max_req_count:
             queue_item = request_queue.get_nowait()
             items.append(queue_item)
-            if _is_executor_request(queue_item):
-                # if it is request, (Not finish signal or cancel signal)
+            if queue_item[0] != SHUTDOWN_REQUEST_ID:
+                # if it is request, not shutdown signal
                 req_count += 1
     except queue.Empty:
         pass
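The drain loop above now assumes every queue item is a tuple whose first element is a request id, so the shutdown signal is detected by value instead of by type. A minimal standalone sketch of the same pattern (the drain helper and its names are illustrative, not the module's API):

import queue

SHUTDOWN_REQUEST_ID = -1

def drain(request_queue: queue.Queue, max_req_count: int) -> list:
    # Collect pending items; count only real requests, not the sentinel.
    items, req_count = [], 0
    try:
        while req_count < max_req_count:
            queue_item = request_queue.get_nowait()
            items.append(queue_item)
            if queue_item[0] != SHUTDOWN_REQUEST_ID:
                req_count += 1
    except queue.Empty:
        pass
    return items

q = queue.Queue()
q.put((7, "request payload"))
q.put((SHUTDOWN_REQUEST_ID, ))
print(drain(q, max_req_count=8))  # [(7, 'request payload'), (-1,)]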
@@ -350,7 +344,7 @@ class PyExecutor:
         """
         try:
             self.enqueue_lock.acquire()
-            self.request_queue.put(None)
+            self.request_queue.put((SHUTDOWN_REQUEST_ID, ))
             self.active = False
         finally:
             self.enqueue_lock.release()
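Putting the sentinel tuple and the `self.active = False` flip under the same lock means no caller can slip a new request in behind the shutdown signal. A hedged sketch of that pairing, with a hypothetical Enqueuer class standing in for PyExecutor:

import queue
import threading

SHUTDOWN_REQUEST_ID = -1

class Enqueuer:
    def __init__(self):
        self.request_queue = queue.Queue()
        self.enqueue_lock = threading.Lock()
        self.active = True

    def enqueue_request(self, req_id, request):
        with self.enqueue_lock:
            if not self.active:
                raise RuntimeError("executor is shutting down")
            self.request_queue.put((req_id, request))

    def signal_shutdown(self):
        # The sentinel and the flag change are atomic relative to
        # enqueue_request, so the sentinel is always the last queue item.
        with self.enqueue_lock:
            self.request_queue.put((SHUTDOWN_REQUEST_ID, ))
            self.active = False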
@@ -1243,8 +1237,8 @@ class PyExecutor:
 
         self.has_context_request = False
         new_requests_cur_rank = []
-        if new_requests != [] and new_requests[
-                0] != None and self.expected_num_active_requests > all_ranks_num_active_requests[
+        if new_requests != [] and new_requests[0][
+                0] != SHUTDOWN_REQUEST_ID and self.expected_num_active_requests > all_ranks_num_active_requests[
                     self.dist.tp_rank]:
             # Balance context tokens across ranks
             HeapVal = namedtuple(
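The guard now reads the request id out of the first tuple element (`new_requests[0][0]`) before entering the balancing branch. The balancing itself is truncated above at `HeapVal = namedtuple(`; as a rough illustration of the idea, not the file's actual implementation, least-loaded assignment with a min-heap looks like this:

import heapq
from collections import namedtuple

# Field names are hypothetical; only the namedtuple-on-a-heap shape is from the diff.
HeapVal = namedtuple("HeapVal", ["num_tokens", "rank"])

def balance(token_counts, num_ranks):
    # Assign each request to the rank with the fewest context tokens so far.
    heap = [HeapVal(0, rank) for rank in range(num_ranks)]
    heapq.heapify(heap)
    assignment = {rank: [] for rank in range(num_ranks)}
    for i, tokens in enumerate(token_counts):
        val = heapq.heappop(heap)
        assignment[val.rank].append(i)
        heapq.heappush(heap, HeapVal(val.num_tokens + tokens, val.rank))
    return assignment

print(balance([512, 64, 256, 128], num_ranks=2))  # {0: [0], 1: [1, 2, 3]}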
@@ -1298,7 +1292,7 @@ class PyExecutor:
         self.num_fetch_requests_cur_rank = self.num_fetch_requests_cur_rank + len(
             new_requests_cur_rank)
 
-        if len(new_requests) == 1 and new_requests[0] is None:
+        if len(new_requests) == 1 and new_requests[0][0] == SHUTDOWN_REQUEST_ID:
            new_requests_cur_rank = new_requests
         return new_requests_cur_rank
 
@@ -1313,15 +1307,12 @@ class PyExecutor:
 
     def _merge_tp_requests(self, new_requests: List[ExecutorRequest]):
         for request in new_requests:
-            if request is None:
+            if request[0] == SHUTDOWN_REQUEST_ID:
                 return True
         for req_item in new_requests:
-            if _is_executor_request(req_item):
-                req_id, exe_req = req_item
-                req = executor_request_to_llm_request(req_id, exe_req)
-                self.active_requests.append(req)
-            elif _is_cancel_request(req_item):
-                self.canceled_req_ids.insert(req_item)
+            req_id, exe_req = req_item
+            req = executor_request_to_llm_request(req_id, exe_req)
+            self.active_requests.append(req)
 
         return False
 
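The rewrite also removes the per-item type dispatch: once every queue item is a tuple, `item[0]` is uniformly valid, whereas the old `None` sentinel forced a type check before any subscripting. A small illustrative comparison:

SHUTDOWN_REQUEST_ID = -1

items_old = [(3, "req"), None]                      # old scheme: mixed types
items_new = [(3, "req"), (SHUTDOWN_REQUEST_ID, )]   # new scheme: tuples only

# Old scheme: subscripting None raises TypeError, so a type check comes first.
for item in items_old:
    if item is None:
        continue
    print(item[0])

# New scheme: every item supports item[0], including the sentinel.
for item in items_new:
    if item[0] == SHUTDOWN_REQUEST_ID:
        continue
    print(item[0])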
@@ -1348,7 +1339,7 @@ class PyExecutor:
         """
         req_id_to_obj = {}
         for item in requests:
-            if item is None:
+            if item[0] == SHUTDOWN_REQUEST_ID:
                 continue
             req_id, req = item[:2]
             obj = getattr(req, attribute_name, None)
@@ -1362,7 +1353,7 @@ class PyExecutor:
         to each request.
         """
         for item in requests:
-            if item is None:
+            if item[0] == SHUTDOWN_REQUEST_ID:
                 continue
             req_id, req = item[:2]
             py_obj = py_request_objects.get(req_id)
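These two loops form a gather/scatter pair: one rank collects per-request Python-side objects by attribute name into a dict keyed by request id, and after a broadcast each rank re-attaches them. A standalone sketch of the round trip; the Req class and the py_custom_state attribute are hypothetical stand-ins:

SHUTDOWN_REQUEST_ID = -1

class Req:
    def __init__(self, py_custom_state=None):
        self.py_custom_state = py_custom_state

def collect(requests, attribute_name):
    # Gather attribute values keyed by request id, skipping the sentinel.
    req_id_to_obj = {}
    for item in requests:
        if item[0] == SHUTDOWN_REQUEST_ID:
            continue
        req_id, req = item[:2]
        obj = getattr(req, attribute_name, None)
        if obj is not None:
            req_id_to_obj[req_id] = obj
    return req_id_to_obj

def attach(requests, attribute_name, py_request_objects):
    # Re-attach previously gathered objects to each request by id.
    for item in requests:
        if item[0] == SHUTDOWN_REQUEST_ID:
            continue
        req_id, req = item[:2]
        py_obj = py_request_objects.get(req_id)
        if py_obj is not None:
            setattr(req, attribute_name, py_obj)

requests = [(1, Req(py_custom_state={"seed": 42})), (SHUTDOWN_REQUEST_ID, )]
gathered = collect(requests, "py_custom_state")   # {1: {'seed': 42}}
attach(requests, "py_custom_state", gathered)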
@@ -1410,61 +1401,57 @@ class PyExecutor:
     def _merge_star_attention_requests(self,
                                        new_requests: List[ExecutorRequest]):
         for request in new_requests:
-            if request is None:
+            if request[0] == SHUTDOWN_REQUEST_ID:
                 return True
         for req_item in new_requests:
-            if _is_executor_request(req_item):
-                req_id, exe_req, query_token_ids = req_item
-                ctx_len0 = len(exe_req.input_token_ids)
-                ctx_blocks, position_blocks, last_block_padding_num = [
-                    exe_req.input_token_ids
-                ], [[i for i in range(ctx_len0)]], 0
-                ctx_blocks, position_blocks, last_block_padding_num = self._partition_context(
-                    exe_req.input_token_ids)
-                if self.dist.cp_rank == self.dist.cp_size - 1 and last_block_padding_num > 0:
-                    ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
-                    position_blocks[-1] = position_blocks[
-                        -1][:-last_block_padding_num]
-                #if has query
-                if query_token_ids:
-                    ctx_blocks.append(query_token_ids)
-                    position_blocks.append([
-                        i for i in range(ctx_len0, ctx_len0 +
-                                         len(query_token_ids))
-                    ])
+            req_id, exe_req, query_token_ids = req_item
+            ctx_len0 = len(exe_req.input_token_ids)
+            ctx_blocks, position_blocks, last_block_padding_num = [
+                exe_req.input_token_ids
+            ], [[i for i in range(ctx_len0)]], 0
+            ctx_blocks, position_blocks, last_block_padding_num = self._partition_context(
+                exe_req.input_token_ids)
+            if self.dist.cp_rank == self.dist.cp_size - 1 and last_block_padding_num > 0:
+                ctx_blocks[-1] = ctx_blocks[-1][:-last_block_padding_num]
+                position_blocks[-1] = position_blocks[
+                    -1][:-last_block_padding_num]
+            #if has query
+            if query_token_ids:
+                ctx_blocks.append(query_token_ids)
+                position_blocks.append([
+                    i for i in range(ctx_len0, ctx_len0 + len(query_token_ids))
+                ])
 
-                # insert the dummy block to align the number of ctx iterations of each rank
-                block_size = self.dist.cp_config['block_size']
-                total_blocks = (ctx_len0 + block_size - 1) // block_size
-                num_blocks_per_rank = (
-                    total_blocks + self.dist.cp_size -
-                    1) // self.dist.cp_size + 1  # 1 for query block
-                if len(ctx_blocks) == num_blocks_per_rank:
-                    ctx_blocks.insert(1, [])
-                    position_blocks.insert(1, [])
-                elif len(ctx_blocks) == num_blocks_per_rank + 1:
-                    # anchor + ctx_blocks + qry_block
-                    pass
-                else:
-                    print(
-                        f'rank = {self.dist.cp_rank}, len(ctx_blocks) = {len(ctx_blocks)}, num_blocks_per_rank = {num_blocks_per_rank}'
-                    )
-                    assert False, f'invalid context partition'
+            # insert the dummy block to align the number of ctx iterations of each rank
+            block_size = self.dist.cp_config['block_size']
+            total_blocks = (ctx_len0 + block_size - 1) // block_size
+            num_blocks_per_rank = (
+                total_blocks + self.dist.cp_size -
+                1) // self.dist.cp_size + 1  # 1 for query block
+            if len(ctx_blocks) == num_blocks_per_rank:
+                ctx_blocks.insert(1, [])
+                position_blocks.insert(1, [])
+            elif len(ctx_blocks) == num_blocks_per_rank + 1:
+                # anchor + ctx_blocks + qry_block
+                pass
+            else:
+                print(
+                    f'rank = {self.dist.cp_rank}, len(ctx_blocks) = {len(ctx_blocks)}, num_blocks_per_rank = {num_blocks_per_rank}'
+                )
+                assert False, f'invalid context partition'
 
-                # fake data for scheduler
-                ctx_blocks_list = [0] * (block_size +
-                                         self.dist.cp_config['cp_anchor_size'])
+            # fake data for scheduler
+            ctx_blocks_list = [0] * (block_size +
+                                     self.dist.cp_config['cp_anchor_size'])
 
-                req = executor_request_to_llm_request(req_id, exe_req,
-                                                      ctx_blocks_list)
-                req.gen_iters = 0
-                req.ctx_iters = 0
-                req.ctx_blocks = ctx_blocks
-                req.ctx_position_blocks = position_blocks
-                req.query_id = query_token_ids
-                self.active_requests.append(req)
-            elif _is_cancel_request(req_item):
-                self.canceled_req_ids.insert(req_item)
+            req = executor_request_to_llm_request(req_id, exe_req,
+                                                  ctx_blocks_list)
+            req.gen_iters = 0
+            req.ctx_iters = 0
+            req.ctx_blocks = ctx_blocks
+            req.ctx_position_blocks = position_blocks
+            req.query_id = query_token_ids
+            self.active_requests.append(req)
 
         return False
 
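The partition check in this hunk is plain ceil-division bookkeeping: ceil(ctx_len0 / block_size) context blocks are spread over cp_size ranks, plus one slot reserved for the query block, and a dummy block is inserted when a rank comes up one short. A worked sketch with made-up sizes:

def expected_blocks(ctx_len0: int, block_size: int, cp_size: int) -> int:
    # Ceil division written with integer arithmetic, as in the diff.
    total_blocks = (ctx_len0 + block_size - 1) // block_size
    return (total_blocks + cp_size - 1) // cp_size + 1  # +1 for the query block

# Example: a 1000-token context, 128-token blocks, 4 CP ranks:
# total_blocks = ceil(1000 / 128) = 8, so each rank expects 8 / 4 + 1 = 3 blocks.
print(expected_blocks(1000, 128, 4))  # 3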