From 136aab5c5406a93fb1da33ce426d6c8acac8a18f Mon Sep 17 00:00:00 2001
From: HuiGao-NV
Date: Sun, 27 Apr 2025 10:24:47 +0800
Subject: [PATCH] fix: Update num_of_ctx_tokens in iteration stats (#3785)

* Update num_of_ctx_tokens in iteration stats

* Revert unnecessary change of importing module

---
 tensorrt_llm/_torch/pyexecutor/py_executor.py | 39 +++++++++++++------
 1 file changed, 28 insertions(+), 11 deletions(-)

diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 4f33d18c28..9248d1fee5 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -21,7 +21,7 @@ import torch
 from tensorrt_llm._utils import global_mpi_rank, nvtx_range
 from tensorrt_llm.bindings.executor import (FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
-                                            RequestType)
+                                            RequestType, StaticBatchingStats)
 from tensorrt_llm.bindings.internal.batch_manager import ReqIdsSet
 from tensorrt_llm.logger import logger
 
@@ -494,11 +494,15 @@ class PyExecutor:
 
     def _get_init_iter_stats(self, num_new_active_requests,
                              new_active_requests_queue_latency_ms):
         stats = IterationStats()
-        stats.timestamp = ""
+        stats.timestamp = datetime.datetime.now().strftime(
+            "%m-%d-%Y %H:%M:%S.%f")
         stats.num_new_active_requests = num_new_active_requests
         stats.num_active_requests = len(self.active_requests)
         stats.new_active_requests_queue_latency_ms = new_active_requests_queue_latency_ms
+        stats.inflight_batching_stats = InflightBatchingStats()
+        # static_batching_stats is not used in the PyTorch path
+        stats.static_batching_stats = StaticBatchingStats()
         return stats
 
     def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
@@ -532,17 +536,17 @@ class PyExecutor:
         kv_stats_to_save.cache_hit_rate = kv_stats.cache_hit_rate
         stats.kv_cache_stats = kv_stats_to_save
 
-        model_stats = InflightBatchingStats()
-        model_stats.num_scheduled_requests = len(
+        stats.inflight_batching_stats.num_scheduled_requests = len(
             scheduled_batch.context_requests) + len(
                 scheduled_batch.generation_requests)
-        model_stats.num_context_requests = len(scheduled_batch.context_requests)
-        model_stats.num_gen_requests = len(scheduled_batch.generation_requests)
-        model_stats.num_paused_requests = len(scheduled_batch.paused_requests)
-        model_stats.avg_num_decoded_tokens_per_iter = 0
-        model_stats.num_ctx_tokens = 0
-        model_stats.micro_batch_id = 0
-        stats.inflight_batching_stats = model_stats
+        stats.inflight_batching_stats.num_context_requests = len(
+            scheduled_batch.context_requests)
+        stats.inflight_batching_stats.num_gen_requests = len(
+            scheduled_batch.generation_requests)
+        stats.inflight_batching_stats.num_paused_requests = len(
+            scheduled_batch.paused_requests)
+        stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0
+        stats.inflight_batching_stats.micro_batch_id = 0
         return stats
 
     def _append_iter_stats(self, stats):
@@ -624,6 +628,10 @@ class PyExecutor:
                     decoder_state = self._forward_step_inter_pp(
                         scheduled_batch)
 
+                if self.enable_iter_perf_stats:
+                    iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                        'num_ctx_tokens']
+
                 batch_state = BatchStatePP(
                     decoder_state=decoder_state,
                     iter_start_time=iter_start_time,
@@ -717,6 +725,9 @@ class PyExecutor:
                                                 scheduled_batch, batch_outputs)
 
                 self._update_request_states(scheduled_batch)
+                if self.enable_iter_perf_stats:
+                    iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                        'num_ctx_tokens']
                 batch_state = BatchStatePP(
                     decoder_state=decoder_state,
                     iter_start_time=iter_start_time,
@@ -887,6 +898,8 @@ class PyExecutor:
                 self._gather_dp_requests_num()
 
                 if self.enable_iter_perf_stats:
+                    iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                        'num_ctx_tokens']
                     self._process_iter_stats(
                         finished_requests,
                         BatchState(decoder_state=DecoderState(
@@ -1040,6 +1053,10 @@ class PyExecutor:
                 if r.get_context_remaining_length() == 0
             ]
 
+            if self.enable_iter_perf_stats:
+                iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                    'num_ctx_tokens']
+
            self.previous_batch = BatchState(
                 decoder_state=decoder_state,
                 iter_start_time=iter_start_time,
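
The pattern behind the patch: inflight_batching_stats is created once in
_get_init_iter_stats and its fields are then filled in place, with
num_ctx_tokens copied from the model engine's iter_states dict after the
forward step rather than hard-coded to 0 as before. The sketch below is a
minimal illustration of that flow under stated assumptions: FakeModelEngine,
run_iteration, and the two dataclasses are hypothetical stand-ins, not the
real pybind11 classes from tensorrt_llm.bindings.executor or PyExecutor's
actual executor loops.

import datetime
from dataclasses import dataclass, field


# Hypothetical stand-ins for the pybind11 stats classes from
# tensorrt_llm.bindings.executor; only the fields used here are modeled.
@dataclass
class InflightBatchingStats:
    num_scheduled_requests: int = 0
    num_ctx_tokens: int = 0


@dataclass
class IterationStats:
    timestamp: str = ""
    inflight_batching_stats: InflightBatchingStats = field(
        default_factory=InflightBatchingStats)


class FakeModelEngine:
    """Hypothetical engine whose forward step records per-iteration
    counters in iter_states, mirroring what the patch reads from
    self.model_engine."""

    def __init__(self):
        self.iter_states = {}

    def forward(self, ctx_token_counts):
        # Total context-phase (prompt) tokens processed this iteration.
        self.iter_states['num_ctx_tokens'] = sum(ctx_token_counts)


def get_init_iter_stats():
    stats = IterationStats()
    # Same timestamp format string the patch writes into stats.timestamp.
    stats.timestamp = datetime.datetime.now().strftime("%m-%d-%Y %H:%M:%S.%f")
    # Created up front so later code can fill fields in place instead of
    # building a fresh InflightBatchingStats with num_ctx_tokens stuck at 0.
    stats.inflight_batching_stats = InflightBatchingStats()
    return stats


def run_iteration(engine, enable_iter_perf_stats):
    iter_stats = get_init_iter_stats()
    engine.forward(ctx_token_counts=[128, 64])  # the forward step
    if enable_iter_perf_stats:
        # The core of the fix: copy the real count instead of leaving 0.
        iter_stats.inflight_batching_stats.num_ctx_tokens = (
            engine.iter_states['num_ctx_tokens'])
    return iter_stats


stats = run_iteration(FakeModelEngine(), enable_iter_perf_stats=True)
print(stats.inflight_batching_stats.num_ctx_tokens)  # 192

The real patch repeats the same guarded assignment at four call sites
because PyExecutor has several executor-loop variants (the hunks touch what
appear to be pipeline-parallel paths using BatchStatePP and non-PP paths
using BatchState); the sketch collapses them into a single run_iteration.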