TensorRT-LLM/tensorrt_llm/_torch/expert_statistic.py
Robin Kobus df41f220a2
[TRTLLM-8831][feat] Enable early exit with overlap scheduler (#8587)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
2025-11-17 18:07:13 +01:00

102 lines
3.5 KiB
Python

import json
import os
import safetensors
import torch
from tensorrt_llm.logger import logger
class ExpertStatistic:
    """Collects per-layer MoE expert-routing statistics over a configurable
    iteration window and dumps them to disk.

    Enabled via the ``EXPERT_STATISTIC_ITER_RANGE`` environment variable
    (format ``"start-stop"``). Counts are recorded for iterations in
    ``[start, stop)`` and flushed when the iteration counter reaches
    ``stop``: meta info as JSON (rank 0 only), per-rank histograms as a
    safetensors file under ``EXPERT_STATISTIC_PATH`` (default
    ``expert_statistic``).
    """

    # Process-wide singleton; None means statistic collection is disabled.
    expert_statistic_obj = None

    @staticmethod
    def get():
        """Return the singleton instance, or None if collection is disabled."""
        return ExpertStatistic.expert_statistic_obj

    @staticmethod
    def create(rank_id: int):
        """Create the singleton if EXPERT_STATISTIC_ITER_RANGE is set.

        Args:
            rank_id: distributed rank of this process; used in the output
                file name and to decide which rank writes the meta info.

        Raises:
            ValueError: if EXPERT_STATISTIC_ITER_RANGE is not of the form
                'start-stop' with integer bounds.
        """
        # Enabled if EXPERT_STATISTIC_ITER_RANGE is set.
        span = os.environ.get('EXPERT_STATISTIC_ITER_RANGE', None)
        if span is None:
            return
        try:
            start_str, stop_str = span.strip().split('-')
            start, stop = int(start_str), int(stop_str)
        except ValueError as e:
            # Re-raise with a message that names the offending variable and
            # the expected format instead of echoing the raw parsing error.
            raise ValueError(
                f"EXPERT_STATISTIC_ITER_RANGE must be of the form "
                f"'start-stop' with integer bounds, got {span!r}") from e
        ExpertStatistic.expert_statistic_obj = ExpertStatistic(
            rank_id, start, stop)

    @staticmethod
    def should_record() -> bool:
        """True when collection is enabled and the current iteration lies
        inside the [start, stop) recording window."""
        if ExpertStatistic.expert_statistic_obj is not None:
            return ExpertStatistic.expert_statistic_obj._should_record
        return False

    @staticmethod
    def set_iter(iter_id: int) -> None:
        """Advance the iteration counter; a no-op when disabled."""
        if ExpertStatistic.expert_statistic_obj is not None:
            ExpertStatistic.expert_statistic_obj._set_iter(iter_id)

    @staticmethod
    def set_layer(layer_id: int) -> None:
        """Set the layer currently being executed; a no-op when disabled."""
        if ExpertStatistic.expert_statistic_obj is not None:
            ExpertStatistic.expert_statistic_obj._set_layer(layer_id)

    @staticmethod
    def maybe_add_info(expert_count: int,
                       token_selected_experts: torch.Tensor) -> None:
        """Record routing decisions for the current (iteration, layer).

        Args:
            expert_count: total number of experts in the MoE layer.
            token_selected_experts: integer tensor of selected expert
                indices; its last dimension is the number of experts
                chosen per token.
        """
        if ExpertStatistic.expert_statistic_obj is not None:
            ExpertStatistic.expert_statistic_obj._maybe_add_info(
                expert_count, token_selected_experts)

    def __init__(self, rank_id: int, start: int, stop: int) -> None:
        self.current_iter_id = None  # updated via _set_iter
        self.current_layer = None  # updated via _set_layer
        self.rank_id = rank_id
        self.start = start  # first iteration to record (inclusive)
        self.stop = stop  # first iteration past the window (exclusive)
        self._meta_info = None  # lazily filled on the first recorded call
        self._records = {}  # "{iter}_{layer}" -> per-expert count tensor

    @property
    def _should_record(self) -> bool:
        return self.current_iter_id is not None and self.start <= self.current_iter_id < self.stop

    def _set_iter(self, iter_id: int) -> None:
        self.current_iter_id = iter_id
        # NOTE(review): results are flushed only when set_iter is called with
        # exactly `stop`; jumping past it leaves the statistics unsaved.
        if iter_id == self.stop:
            logger.info(
                f'[ExpertStatistic] Rank={self.rank_id}, saving iter={iter_id}, start={self.start}, stop={self.stop}'
            )
            path = os.environ.get('EXPERT_STATISTIC_PATH', 'expert_statistic')
            # exist_ok avoids a race with other ranks creating the directory.
            os.makedirs(path, exist_ok=True)
            if self.rank_id == 0:
                with open(f"{path}/meta_info.json", "w") as f:
                    json.dump(self._meta_info, f)
            safetensors.torch.save_file(
                self._records, f"{path}/rank{self.rank_id}.safetensors")

    def _set_layer(self, layer: int) -> None:
        self.current_layer = layer

    def _maybe_add_info(self, expert_count: int,
                        token_selected_experts: torch.Tensor) -> None:
        if not self._should_record:
            return
        if self._meta_info is None:
            self._meta_info = {
                "num_experts": expert_count,
                "num_experts_per_token": token_selected_experts.size(-1)
            }
        # Histogram of expert selections, padded to the full expert count;
        # moved to CPU once so both branches below reuse the same tensor.
        counts = torch.bincount(token_selected_experts.flatten(),
                                minlength=expert_count).cpu()
        key = f"{self.current_iter_id}_{self.current_layer}"
        if key in self._records:
            self._records[key] += counts
        else:
            self._records[key] = counts