Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00
fix: Update num_of_ctx_tokens in iteration stats (#3785)
* Update num_of_ctx_tokens in iteration stats
* Revert unnecessary change of importing module
This commit is contained in:
parent a4b483b969
commit 136aab5c54
@@ -21,7 +21,7 @@ import torch
 from tensorrt_llm._utils import global_mpi_rank, nvtx_range
 from tensorrt_llm.bindings.executor import (FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
-                                            RequestType)
+                                            RequestType, StaticBatchingStats)
 from tensorrt_llm.bindings.internal.batch_manager import ReqIdsSet
 from tensorrt_llm.logger import logger
 
@@ -494,11 +494,15 @@ class PyExecutor:
     def _get_init_iter_stats(self, num_new_active_requests,
                              new_active_requests_queue_latency_ms):
         stats = IterationStats()
-        stats.timestamp = ""
+        stats.timestamp = datetime.datetime.now().strftime(
+            "%m-%d-%Y %H:%M:%S.%f")
 
         stats.num_new_active_requests = num_new_active_requests
         stats.num_active_requests = len(self.active_requests)
         stats.new_active_requests_queue_latency_ms = new_active_requests_queue_latency_ms
+        stats.inflight_batching_stats = InflightBatchingStats()
+        # staticBatchingStats is not used in pytorch path
+        stats.static_batching_stats = StaticBatchingStats()
         return stats
 
     def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
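For illustration, a minimal runnable sketch of the init pattern above. The Mock* dataclasses are hypothetical stand-ins for the pybind-backed IterationStats and InflightBatchingStats types, not the real bindings:

import datetime
from dataclasses import dataclass, field

@dataclass
class MockInflightBatchingStats:
    num_scheduled_requests: int = 0
    num_ctx_tokens: int = 0

@dataclass
class MockIterationStats:
    timestamp: str = ""
    inflight_batching_stats: MockInflightBatchingStats = field(
        default_factory=MockInflightBatchingStats)

def get_init_iter_stats() -> MockIterationStats:
    stats = MockIterationStats()
    # Same timestamp format the diff switches to: month-day-year with microseconds.
    stats.timestamp = datetime.datetime.now().strftime("%m-%d-%Y %H:%M:%S.%f")
    return stats

stats = get_init_iter_stats()
# Later steps fill fields of the pre-created nested object in place.
stats.inflight_batching_stats.num_ctx_tokens = 128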
@@ -532,17 +536,17 @@ class PyExecutor:
         kv_stats_to_save.cache_hit_rate = kv_stats.cache_hit_rate
         stats.kv_cache_stats = kv_stats_to_save
 
-        model_stats = InflightBatchingStats()
-        model_stats.num_scheduled_requests = len(
+        stats.inflight_batching_stats.num_scheduled_requests = len(
             scheduled_batch.context_requests) + len(
                 scheduled_batch.generation_requests)
-        model_stats.num_context_requests = len(scheduled_batch.context_requests)
-        model_stats.num_gen_requests = len(scheduled_batch.generation_requests)
-        model_stats.num_paused_requests = len(scheduled_batch.paused_requests)
-        model_stats.avg_num_decoded_tokens_per_iter = 0
-        model_stats.num_ctx_tokens = 0
-        model_stats.micro_batch_id = 0
-        stats.inflight_batching_stats = model_stats
+        stats.inflight_batching_stats.num_context_requests = len(
+            scheduled_batch.context_requests)
+        stats.inflight_batching_stats.num_gen_requests = len(
+            scheduled_batch.generation_requests)
+        stats.inflight_batching_stats.num_paused_requests = len(
+            scheduled_batch.paused_requests)
+        stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0
+        stats.inflight_batching_stats.micro_batch_id = 0
         return stats
 
     def _append_iter_stats(self, stats):
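A plausible reading of this hunk, sketched with the same hypothetical mock types as above: the old code built a fresh model_stats and assigned it over stats.inflight_batching_stats, zeroing num_ctx_tokens in the process, while the new code mutates the object created in _get_init_iter_stats, so a counter written elsewhere in the loop survives the update:

from dataclasses import dataclass, field

@dataclass
class MockInflightBatchingStats:
    num_scheduled_requests: int = 0
    num_ctx_tokens: int = 0

@dataclass
class MockIterationStats:
    inflight_batching_stats: MockInflightBatchingStats = field(
        default_factory=MockInflightBatchingStats)

# Old pattern: replacing the nested object loses the previously set counter.
stats = MockIterationStats()
stats.inflight_batching_stats.num_ctx_tokens = 512  # written by the forward path
model_stats = MockInflightBatchingStats()           # num_ctx_tokens reset to 0
model_stats.num_scheduled_requests = 8
stats.inflight_batching_stats = model_stats
assert stats.inflight_batching_stats.num_ctx_tokens == 0  # counter lost

# New pattern: in-place updates keep the counter intact.
stats = MockIterationStats()
stats.inflight_batching_stats.num_ctx_tokens = 512
stats.inflight_batching_stats.num_scheduled_requests = 8
assert stats.inflight_batching_stats.num_ctx_tokens == 512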
@@ -624,6 +628,10 @@ class PyExecutor:
                     decoder_state = self._forward_step_inter_pp(
                         scheduled_batch)
 
+                if self.enable_iter_perf_stats:
+                    iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                        'num_ctx_tokens']
+
                 batch_state = BatchStatePP(
                     decoder_state=decoder_state,
                     iter_start_time=iter_start_time,
@@ -717,6 +725,9 @@ class PyExecutor:
                     scheduled_batch, batch_outputs)
                 self._update_request_states(scheduled_batch)
 
+            if self.enable_iter_perf_stats:
+                iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                    'num_ctx_tokens']
             batch_state = BatchStatePP(
                 decoder_state=decoder_state,
                 iter_start_time=iter_start_time,
@@ -887,6 +898,8 @@ class PyExecutor:
                     self._gather_dp_requests_num()
 
+                if self.enable_iter_perf_stats:
+                    iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                        'num_ctx_tokens']
                 self._process_iter_stats(
                     finished_requests,
                     BatchState(decoder_state=DecoderState(
@@ -1040,6 +1053,10 @@ class PyExecutor:
                 if r.get_context_remaining_length() == 0
             ]
 
+            if self.enable_iter_perf_stats:
+                iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                    'num_ctx_tokens']
+
             self.previous_batch = BatchState(
                 decoder_state=decoder_state,
                 iter_start_time=iter_start_time,
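The same guarded copy appears at each of the four sites above. A small sketch of the pattern follows; MockEngine and its iter_states dict are hypothetical stand-ins for self.model_engine in PyExecutor:

class MockEngine:
    # Hypothetical stand-in: the real engine is assumed to record
    # per-iteration counters in an iter_states dict during forward.
    def __init__(self):
        self.iter_states = {'num_ctx_tokens': 0}

    def forward(self, num_ctx_tokens):
        self.iter_states['num_ctx_tokens'] = num_ctx_tokens

class MockInflightStats:
    num_ctx_tokens = 0

enable_iter_perf_stats = True
engine = MockEngine()
inflight = MockInflightStats()

engine.forward(num_ctx_tokens=384)
# Copy the engine's counter into the iteration stats only when the
# perf-stats feature is enabled, mirroring the guard in the diff.
if enable_iter_perf_stats:
    inflight.num_ctx_tokens = engine.iter_states['num_ctx_tokens']
assert inflight.num_ctx_tokens == 384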