TensorRT-LLMs/tests/dump_checkpoint_stats.py
Kaiyu Xie 4bb65f216f
Update TensorRT-LLM (#1274)
* Update TensorRT-LLM

---------

Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2024-03-12 18:15:52 +08:00

30 lines
903 B
Python

import json
import os
import sys
import safetensors
def dump_stats(ckpt_dir):
    """Print shape, max and min of every tensor in a sharded checkpoint.

    Reads ``config.json`` from ``ckpt_dir`` and takes the number of shards
    to be ``mapping.tp_size * mapping.pp_size``. For each rank in that range
    the file ``rank{rank}.safetensors`` is opened (PyTorch framework, CPU
    device) and one line per stored tensor is printed to stdout.

    Args:
        ckpt_dir: Directory containing ``config.json`` and the
            ``rank{N}.safetensors`` shard files.

    Raises:
        FileNotFoundError: If ``config.json`` cannot be found in ``ckpt_dir``.
        KeyError: If ``config.json`` lacks ``mapping.tp_size`` or
            ``mapping.pp_size``.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(os.path.join(ckpt_dir, "config.json"), encoding="utf-8") as c:
        config = json.load(c)

    mapping = config['mapping']
    world_size = mapping['tp_size'] * mapping['pp_size']

    for rank in range(world_size):
        shard_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors')
        with safetensors.safe_open(shard_path, framework='pt',
                                   device='cpu') as f:
            for key in f.keys():
                tensor = f.get_tensor(key)
                if tensor.numel() == 0:
                    # max()/min() raise on zero-element tensors; report None
                    # instead of aborting the dump of the remaining tensors.
                    max_val = min_val = None
                else:
                    max_val = tensor.max().item()
                    min_val = tensor.min().item()
                print(
                    f"rank-{rank}:{key}, shape:{list(tensor.shape)}, max:{max_val}, min:{min_val}"
                )
if __name__ == "__main__":
    # Guarded so that importing this module does not trigger a checkpoint
    # dump or fail on a missing command-line argument.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <ckpt_dir>")
    dump_stats(sys.argv[1])