Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][fix] Setup dist for AutoTuner in Layerwise benchmarking. (#10534)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
parent 6fcd4e7099
commit c5331e6dbb
@@ -10,10 +10,11 @@ import torch
 import yaml
 
 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
+from tensorrt_llm._torch.distributed import MPIDist, TorchDist
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_cutlass import CutlassFusedMoE
 from tensorrt_llm._torch.modules.fused_moe.interface import AlltoallMethodType
 from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
-from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
+from tensorrt_llm._utils import local_mpi_rank, mpi_disabled, mpi_rank, mpi_world_size
 from tensorrt_llm.logger import logger
 from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls, mark_ranges
 
@@ -173,6 +174,8 @@ run_pack = runner.create_run_pack(
 )
 if args.enable_autotuner:
     cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None
+    dist = TorchDist(mapping=mapping) if mpi_disabled() else MPIDist(mapping=mapping)
+    AutoTuner.get().setup_distributed_state(mapping, dist)
     with autotune(cache_path=cache_path):
         run_pack()
 else:
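
For context, here is a minimal sketch of the flow this commit establishes, assembled from the two hunks above: pick a torch.distributed-backed communicator when MPI is disabled (e.g. a torchrun launch) and an MPI-backed one otherwise, hand it to the AutoTuner singleton, and only then enter the autotune context. The Mapping construction and the run_pack placeholder below are illustrative assumptions; the real benchmark script derives the mapping from its CLI arguments and obtains run_pack from runner.create_run_pack(...).

import os

from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.distributed import MPIDist, TorchDist
from tensorrt_llm._utils import mpi_disabled, mpi_rank, mpi_world_size
from tensorrt_llm.mapping import Mapping

# Assumption: a plain tensor-parallel mapping over all ranks; the actual
# script builds this from its command-line arguments.
mapping = Mapping(world_size=mpi_world_size(),
                  rank=mpi_rank(),
                  tp_size=mpi_world_size())

def run_pack():
    # Placeholder for the callable returned by runner.create_run_pack(...).
    pass

cache_path = os.getenv("TLLM_AUTOTUNER_CACHE_PATH") or None

# Choose the communicator backend to match the launcher in use:
# TorchDist when MPI is disabled, MPIDist when running under mpirun.
dist = TorchDist(mapping=mapping) if mpi_disabled() else MPIDist(mapping=mapping)

# The fix itself: register the distributed state with the AutoTuner before
# tuning starts, so work inside the autotune() context has a communicator
# available across ranks.
AutoTuner.get().setup_distributed_state(mapping, dist)

with autotune(cache_path=cache_path):
    run_pack()

Registering the state before entering autotune() matters because the tuner is accessed as a process-wide singleton (AutoTuner.get()); selecting the backend via mpi_disabled() is what lets the same benchmark script run under either mpirun or torchrun.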