fix: Update num_of_ctx_tokens in iteration stats (#3785)

* Update num_of_ctx_tokens in iteration stats
* Revert unnecessary import change
HuiGao-NV 2025-04-27 10:24:47 +08:00 committed by GitHub
parent a4b483b969
commit 136aab5c54

@@ -21,7 +21,7 @@ import torch

 from tensorrt_llm._utils import global_mpi_rank, nvtx_range
 from tensorrt_llm.bindings.executor import (FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
-                                            RequestType)
+                                            RequestType, StaticBatchingStats)
 from tensorrt_llm.bindings.internal.batch_manager import ReqIdsSet
 from tensorrt_llm.logger import logger
@@ -494,11 +494,15 @@ class PyExecutor:
     def _get_init_iter_stats(self, num_new_active_requests,
                              new_active_requests_queue_latency_ms):
         stats = IterationStats()
-        stats.timestamp = ""
+        stats.timestamp = datetime.datetime.now().strftime(
+            "%m-%d-%Y %H:%M:%S.%f")
         stats.num_new_active_requests = num_new_active_requests
         stats.num_active_requests = len(self.active_requests)
         stats.new_active_requests_queue_latency_ms = new_active_requests_queue_latency_ms
+        stats.inflight_batching_stats = InflightBatchingStats()
+        # staticBatchingStats is not used in pytorch path
+        stats.static_batching_stats = StaticBatchingStats()

         return stats

     def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
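The hunk above swaps the empty-string timestamp placeholder for a wall-clock string. A quick standalone check of the format (standard library only; the printed value is illustrative):

    import datetime

    # Format written into IterationStats.timestamp by the patched
    # _get_init_iter_stats: month-day-year with microsecond precision,
    # e.g. "04-27-2025 10:24:47.123456".
    print(datetime.datetime.now().strftime("%m-%d-%Y %H:%M:%S.%f"))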
@@ -532,17 +536,17 @@ class PyExecutor:
         kv_stats_to_save.cache_hit_rate = kv_stats.cache_hit_rate
         stats.kv_cache_stats = kv_stats_to_save

-        model_stats = InflightBatchingStats()
-        model_stats.num_scheduled_requests = len(
+        stats.inflight_batching_stats.num_scheduled_requests = len(
             scheduled_batch.context_requests) + len(
                 scheduled_batch.generation_requests)
-        model_stats.num_context_requests = len(scheduled_batch.context_requests)
-        model_stats.num_gen_requests = len(scheduled_batch.generation_requests)
-        model_stats.num_paused_requests = len(scheduled_batch.paused_requests)
-        model_stats.avg_num_decoded_tokens_per_iter = 0
-        model_stats.num_ctx_tokens = 0
-        model_stats.micro_batch_id = 0
-        stats.inflight_batching_stats = model_stats
+        stats.inflight_batching_stats.num_context_requests = len(
+            scheduled_batch.context_requests)
+        stats.inflight_batching_stats.num_gen_requests = len(
+            scheduled_batch.generation_requests)
+        stats.inflight_batching_stats.num_paused_requests = len(
+            scheduled_batch.paused_requests)
+        stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0
+        stats.inflight_batching_stats.micro_batch_id = 0
         return stats

     def _append_iter_stats(self, stats):
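This hunk is the core of the fix: _update_iter_stats now fills in the InflightBatchingStats instance created in _get_init_iter_stats instead of building a fresh object and swapping it in, and it no longer writes num_ctx_tokens = 0, so a real count assigned elsewhere in the iteration is not clobbered. A minimal sketch of the difference, with a stub class in place of the binding and hypothetical values:

    # Stub standing in for the real InflightBatchingStats binding.
    class InflightBatchingStats:
        num_ctx_tokens = 0
        num_gen_requests = 0

    stats = InflightBatchingStats()
    stats.num_ctx_tokens = 512   # hypothetical count set during the iteration

    # Old pattern (lossy): a freshly built object with num_ctx_tokens = 0
    # replaced `stats`, discarding the count above.
    lossy = InflightBatchingStats()
    lossy.num_gen_requests = 4   # but num_ctx_tokens is 0 here

    # New pattern: mutate the existing object and leave num_ctx_tokens alone.
    stats.num_gen_requests = 4
    assert stats.num_ctx_tokens == 512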
@@ -624,6 +628,10 @@ class PyExecutor:
                     decoder_state = self._forward_step_inter_pp(
                         scheduled_batch)

+                if self.enable_iter_perf_stats:
+                    iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                        'num_ctx_tokens']
+
                 batch_state = BatchStatePP(
                     decoder_state=decoder_state,
                     iter_start_time=iter_start_time,
@@ -717,6 +725,9 @@ class PyExecutor:
                 scheduled_batch, batch_outputs)
             self._update_request_states(scheduled_batch)

+            if self.enable_iter_perf_stats:
+                iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                    'num_ctx_tokens']
             batch_state = BatchStatePP(
                 decoder_state=decoder_state,
                 iter_start_time=iter_start_time,
@@ -887,6 +898,8 @@ class PyExecutor:
             self._gather_dp_requests_num()

             if self.enable_iter_perf_stats:
+                iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                    'num_ctx_tokens']
                 self._process_iter_stats(
                     finished_requests,
                     BatchState(decoder_state=DecoderState(
@@ -1040,6 +1053,10 @@ class PyExecutor:
                 if r.get_context_remaining_length() == 0
             ]

+            if self.enable_iter_perf_stats:
+                iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
+                    'num_ctx_tokens']
+
             self.previous_batch = BatchState(
                 decoder_state=decoder_state,
                 iter_start_time=iter_start_time,
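Taken together, the loop hunks apply the same two lines in each executor-loop variant: when enable_iter_perf_stats is set, the context-token count recorded by the model engine is copied into the published in-flight batching stats. A self-contained sketch of the resulting flow, with stub classes in place of the tensorrt_llm bindings and a hand-filled iter_states dict (how the engine populates that dict is an assumption; only the 'num_ctx_tokens' key appears in the diff):

    import datetime

    # Stubs standing in for the tensorrt_llm.bindings.executor classes.
    class InflightBatchingStats:
        num_ctx_tokens = 0

    class IterationStats:
        def __init__(self):
            self.timestamp = ""
            self.inflight_batching_stats = None

    def get_init_iter_stats():
        stats = IterationStats()
        stats.timestamp = datetime.datetime.now().strftime(
            "%m-%d-%Y %H:%M:%S.%f")
        # Created up front so the loop below can assign fields in place.
        stats.inflight_batching_stats = InflightBatchingStats()
        return stats

    # Hand-filled stand-in for model_engine.iter_states after a forward step.
    iter_states = {'num_ctx_tokens': 128}

    enable_iter_perf_stats = True
    iter_stats = get_init_iter_stats()
    if enable_iter_perf_stats:
        # The two lines each hunk adds: publish the engine's real count
        # instead of the zero placeholder the old code wrote.
        iter_stats.inflight_batching_stats.num_ctx_tokens = iter_states[
            'num_ctx_tokens']

    print(iter_stats.timestamp,
          iter_stats.inflight_batching_stats.num_ctx_tokens)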