import itertools
import os
import pickle
import sys
import tempfile
from typing import Any, List

import cloudpickle
import pytest
import torch
from mpi4py import MPI

import tensorrt_llm
import tensorrt_llm._torch.autotuner as autotuner
from tensorrt_llm._torch.autotuner import (AutoTuner, DistributedTuningStrategy,
                                           DynamicDim, DynamicTensorSpec,
                                           FakeTensor, OptimizationProfile,
                                           StaticDim, TunableRunner,
                                           TuningConfig, autotune)
from tensorrt_llm._torch.distributed.communicator import MPIDist, TorchDist
from tensorrt_llm._torch.utils import (get_power_of_2_num_tokens_buckets,
                                       next_positive_power_of_2)
from tensorrt_llm._utils import mpi_disabled
from tensorrt_llm.bindings.internal.runtime import delay_kernel
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
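
# Register this test module by value so cloudpickle can serialize the classes
# and functions defined here when they are shipped to MPI worker processes,
# and route MPI's pickling through cloudpickle for the same reason.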
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
    cloudpickle.dumps,
    cloudpickle.loads,
    pickle.HIGHEST_PROTOCOL,
)

# Needed since we reuse the MPI executor pool; the first test to run would
# otherwise leak a thread.
pytestmark = pytest.mark.threadleak(enabled=False)


def test_multi_dynamic_dims():
    tuner = autotuner.AutoTuner()
    x = torch.rand([5, 1024])
    w = torch.rand([7, 19])
    dynamic_tensor_specs = (
        DynamicTensorSpec(0, 0, [1, 3, 5]),
        DynamicTensorSpec(0, 1, [16, 24, 1024]),
        DynamicTensorSpec(1, 1, [3, 7, 9], lambda x: x // 2),
    )

    profiles = tuner._optimization_profiles(
        tuning_config=TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs),
        inputs=[x, w])
    # choice(0, 0) * choice(0, 1) * choice(1, 1)
    # 3 * 3 * 3 = 27, because 19 maps to 9, which is already one of the buckets
    assert len(profiles) == 27
    sample_0 = OptimizationProfile(shapes=[[
        DynamicDim(min=1, opt=1, max=3),
        DynamicDim(min=16, opt=16, max=24)
    ], [StaticDim(val=7), DynamicDim(min=3, opt=3, max=7)]])
    sample_26 = OptimizationProfile(shapes=[[
        DynamicDim(min=5, opt=5, max=float('inf')),
        DynamicDim(min=1024, opt=1024, max=float('inf'))
    ], [StaticDim(
        val=7), DynamicDim(min=9, opt=9, max=float('inf'))]])

    assert sample_0 == profiles[0]
    assert sample_26 == profiles[-1]


# For cache testing
"""
tactic 0 is better when x.shape[0] <= M // 2
tactic 1 is better when x.shape[0] > M // 2
"""
M = 32


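# The three simulated GEMM kernels below all compute the same x @ w; only the
# injected delay differs, so the autotuner's tactic choice is driven purely by
# the artificial timing.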
# Add an artificial delay to simulate bad perf
def gemm_0(x, w):
    if x.shape[0] > M // 2:
        delay_kernel(100, torch.cuda.current_stream())
    return x @ w


def gemm_1(x, w):
    if x.shape[0] <= M // 2:
        delay_kernel(100, torch.cuda.current_stream())
    return x @ w


def gemm_fallback(x, w) -> torch.Tensor:
    # always the slowest
    delay_kernel(500, torch.cuda.current_stream())
    return x @ w


def check_gemm_tactic_valid(tactic: int, m: int) -> None:
    # TODO: CI is not stable for this test; delay_kernel cannot guarantee the
    # profiling result. We need to find a more deterministic way to test this.
    if m <= M // 2:
        if tactic != 0:
            logger.warning(
                f"Expect tactic 0 but got {tactic} when m ({m}) is small.")
    elif m <= M:
        if tactic != 1:
            logger.warning(
                f"Expect tactic 1 but got {tactic} when m ({m}) is large.")
    else:
        if tactic != -1:
            logger.warning(
                f"Expect fallback tactic (-1) but got {tactic} when m ({m}) > {M}."
            )


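# Minimal TunableRunner: get_valid_tactics() enumerates the candidate tactics
# for a given optimization profile, and forward() runs the tactic chosen by
# the tuner (tactic=-1 selects the fallback implementation).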
class GemmRunner(TunableRunner):

    def get_valid_tactics(self, inputs: List[FakeTensor],
                          profile: OptimizationProfile, **kwargs) -> List[int]:
        # The simulated delay is not deterministic, so we need to return
        # specific tactics here
        return [-1, 0, 1]

    def forward(self,
                /,
                inputs: List[torch.Tensor],
                *,
                tactic: int = -1,
                **kwargs) -> torch.Tensor:
        assert tactic in [-1, 0, 1]
        return [gemm_0, gemm_1, gemm_fallback][tactic](*inputs)


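# Thin torch custom op around AutoTuner.choose_one so the tuner can be driven
# through torch.ops in the cache tests; it returns the chosen tactic as a
# tensor so the tests can inspect it.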
@torch.library.custom_op("autotuner_test::get_best_gemm_tactic",
                         mutates_args=())
def get_best_gemm_tactic(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    runners = [GemmRunner()]
    tuner = AutoTuner.get()
    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
        input_idx=0,
        dim_idx=0,
        gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
        map_to_tuning_buckets=next_positive_power_of_2), ), )
    runner, tactic = tuner.choose_one(
        "autotuner_test::get_best_gemm_tactic",
        runners,
        tuning_config,
        [x, w],
    )
    return torch.tensor(tactic)


@get_best_gemm_tactic.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.empty(1)


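# The cache test below tunes once at the largest token count (M) and then
# expects smaller m values to be served from the cached buckets without
# re-tuning; for m > M it expects the fallback tactic (-1).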
def test_autotuner_cache_basic():
    w = torch.randn(64, 128)

    # tuning with largest M
    AutoTuner.get().clear_cache()
    with autotune():
        torch.ops.autotuner_test.get_best_gemm_tactic(torch.randn(M, 64), w)

    # This tests the logic of print_profiling_cache and print_statistics
    AutoTuner.get().print_profiling_cache()
    AutoTuner.get().print_statistics()

    m = M * 2
    while m >= 1:
        best_tactic = torch.ops.autotuner_test.get_best_gemm_tactic(
            torch.randn(m, 64), w)
        check_gemm_tactic_valid(best_tactic, m)
        m //= 2


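# Profiling failures should be tolerated: a runner may raise for a tactic
# during tuning, and the autotuner is expected to skip that tactic rather
# than abort, so tactic 1 must never be selected here.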
def test_autotuner_try_block():

    class PartialCrashedRunner(TunableRunner):

        def get_valid_tactics(self, inputs: List[FakeTensor],
                              profile: OptimizationProfile,
                              **kwargs) -> List[int]:
            return [-1, 0, 1]

        def forward(self,
                    /,
                    inputs: List[torch.Tensor],
                    *,
                    tactic: int = -1) -> torch.Tensor:
            assert tactic in [-1, 0, 1]
            if tactic == 1:
                raise Exception(
                    "For profiling try block test: Tactic 1 is not suitable. Crash happens."
                )
            return [gemm_0, gemm_1, gemm_fallback][tactic](*inputs)

    x, w = torch.randn(M, 64), torch.randn(64, 128)
    runners = [PartialCrashedRunner()]
    tuner = AutoTuner.get()
    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
        input_idx=0,
        dim_idx=0,
        gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
        map_to_tuning_buckets=next_positive_power_of_2), ), )
    with autotune():
        runner, tactic = tuner.choose_one("test_autotuner_try_block", runners,
                                          tuning_config, [x, w])

    m = M // 2
    while m >= 1:
        _, tactic = tuner.choose_one("test_autotuner_try_block", runners,
                                     tuning_config, [torch.randn(m, 64), w])
        assert tactic in [
            -1, 0
        ], f"Expect only tactic -1, 0 being chosen, but got tactic {tactic}."
        m //= 2


@torch.library.custom_op("autotuner_test::recursive_get_best_gemm_tactic",
                         mutates_args=())
def recursive_get_best_gemm_tactic(x: torch.Tensor, w1: torch.Tensor,
                                   w2: torch.Tensor) -> torch.Tensor:
    # Only the first call is tuned; the second one reuses the tuned result
    # from the cache
    tactic_1 = get_best_gemm_tactic(x, w1)
    tactic_2 = get_best_gemm_tactic(x, w2)
    return torch.stack([tactic_1, tactic_2])


@recursive_get_best_gemm_tactic.register_fake
def _(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    return torch.empty(2)


def test_recursive_autotuner():
    x, w1, w2 = torch.randn(M, 64), torch.randn(64, 128), torch.randn(64, 128)
    AutoTuner.get().clear_cache()
    with autotune():
        torch.ops.autotuner_test.recursive_get_best_gemm_tactic(
            torch.randn(M, 64), w1, w2)

    m = M * 2
    while m >= 1:
        t1, t2 = torch.ops.autotuner_test.recursive_get_best_gemm_tactic(
            torch.randn(m, 64), w1, w2)
        check_gemm_tactic_valid(t1, m)
        check_gemm_tactic_valid(t2, m)
        m //= 2


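# Runner whose constructor arguments (block_size, num_warps) act as tuning
# attributes; the test below checks that runners created with different
# attributes produce different profiling-cache keys.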
class GemmRunnerWithAttributes(TunableRunner):

    def __init__(self, block_size: int, num_warps: int):
        self.block_size = block_size
        self.num_warps = num_warps

    def get_valid_tactics(self, inputs: List[FakeTensor],
                          profile: OptimizationProfile, **kwargs) -> List[int]:
        return [-1, 0, 1]

    def forward(self,
                /,
                inputs: List[torch.Tensor],
                *,
                tactic: int = -1) -> torch.Tensor:
        assert tactic in [-1, 0, 1]
        return [gemm_0, gemm_1, gemm_fallback][tactic](*inputs)


def test_multiple_runners_different_attributes():
    """Test that runners with different attributes get different cache entries"""
    x, w = torch.randn(16, 64), torch.randn(64, 128)

    # Create runners with different attributes
    runner_0 = GemmRunnerWithAttributes(block_size=128, num_warps=4)
    runner_1 = GemmRunnerWithAttributes(block_size=256, num_warps=8)
    runners = [runner_0, runner_1]

    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
        input_idx=0,
        dim_idx=0,
        gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
        map_to_tuning_buckets=next_positive_power_of_2), ), )

    # Do tuning
    with autotune():
        tuner = AutoTuner.get()
        runner_a, tactic_a = tuner.choose_one("test_multiple_runners", runners,
                                              tuning_config, [x, w])

    # Verify different cache keys are generated
    shapes = (x.shape, w.shape)
    cache_key_0 = tuner.profiling_cache.get_cache_key(
        custom_op="test_multiple_runners",
        input_shapes=shapes,
        runner=runner_0,
        tuning_config=tuning_config,
    )
    cache_key_1 = tuner.profiling_cache.get_cache_key(
        custom_op="test_multiple_runners",
        input_shapes=shapes,
        runner=runner_1,
        tuning_config=tuning_config,
    )

    assert cache_key_0 != cache_key_1, "Runners with different attributes should have different cache keys"


def test_multiple_dynamic_shapes_cache():
    """Test that different dynamic shape combinations are properly cached"""
    w = torch.randn(64, 128)
    runners = [GemmRunner()]

    # Define dynamic ranges for both dimensions
    tuning_config = TuningConfig(dynamic_tensor_specs=(
        DynamicTensorSpec(input_idx=0,
                          dim_idx=0,
                          gen_tuning_buckets=(3, 4, 5),
                          map_to_tuning_buckets=lambda x: x),
        DynamicTensorSpec(input_idx=1,
                          dim_idx=1,
                          gen_tuning_buckets=(64, 128, 256, 512),
                          map_to_tuning_buckets=lambda x: x),
    ), )

    # Do tuning with a sample input
    x = torch.randn(3, 64)
    temp_dir = tempfile.TemporaryDirectory()
    cache_path = os.path.join(temp_dir.name,
                              "test_multiple_dynamic_shapes.json")
    with autotune(cache_path=cache_path):
        tuner = AutoTuner.get()
        runner, tactic = tuner.choose_one("test_multiple_dynamic_shapes",
                                          runners, tuning_config, [x, w])

    # Verify cache size - should have 12 entries (3x4 combinations)
    cache_entries = tuner.profiling_cache.get_specific_custom_op(
        "test_multiple_dynamic_shapes")
    assert len(cache_entries) == 12, \
        f"Expected 12 cache entries for 3x4 shape combinations, got {len(cache_entries)}"

    # We also test the cache serialization and deserialization here.
    AutoTuner.get().profiling_cache.clear()
    AutoTuner.get().profiling_cache.load_cache(cache_path, rank=0)
    cache_entries = tuner.profiling_cache.get_specific_custom_op(
        "test_multiple_dynamic_shapes")

    assert len(cache_entries) == 12, \
        f"Expected 12 cache entries for 3x4 shape combinations, got {len(cache_entries)}"


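# Runner whose tactics are dictionaries mixing ints, tuples and lists, used to
# exercise (de)serialization of non-trivial tactic objects through the JSON
# profiling cache, together with tune_max_num_tokens clipping and the
# inputs_pre_hook mechanism.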
class GemmRunnerComplexTuningConfigs(TunableRunner):

    # test serialization of different types of tactics
    valid_tactic_ids = [-1, 0, 1]
    valid_tile_sizes = [(128, 128), (256, 256)]
    valid_cluster_sizes = [[1, 1, 1], [2, 2, 1]]

    tune_max_num_tokens = 32

    def get_valid_tactics(
        self,
        inputs: List[FakeTensor],
        profile: OptimizationProfile,
        **kwargs,
    ) -> List[Any]:
        # During tuning, verify that the tuning config behaves as expected
        assert inputs[0].shape[0] <= self.tune_max_num_tokens, \
            f"Input shape {inputs[0].shape[0]} is larger than the max num tokens {self.tune_max_num_tokens}"

        assert inputs[0][-1, 0] == inputs[0].shape[0], \
            f"Input shape {inputs[0].shape[0]} is not set through the pre_hook correctly"

        return [{
            "int_tactic_id": tactic_id,
            "tuple_tile_size": tile_size,
            "list_cluster_size": cluster_size,
        } for tactic_id, tile_size, cluster_size in itertools.product(
            self.valid_tactic_ids,
            self.valid_tile_sizes,
            self.valid_cluster_sizes,
        )]

    def forward(
        self,
        /,
        inputs: List[torch.Tensor],
        *,
        tactic: Any = -1,
    ) -> torch.Tensor:
        # Note that in the fallback case the tactic is -1
        if tactic == -1:
            # assign default configs for the fallback case
            tactic_id, tile_size, cluster_size = -1, (128, 256), [1, 1, 1]
        else:
            tactic_id, tile_size, cluster_size = tactic[
                "int_tactic_id"], tactic["tuple_tile_size"], tactic[
                    "list_cluster_size"]

        assert isinstance(tactic_id, int) and tactic_id in self.valid_tactic_ids
        assert isinstance(tile_size, tuple) and len(tile_size) == 2 \
            and tile_size in self.valid_tile_sizes
        assert isinstance(cluster_size, list) and len(cluster_size) == 3 \
            and cluster_size in self.valid_cluster_sizes
        return [gemm_0, gemm_1, gemm_fallback][tactic_id](*inputs)

    @staticmethod
    def inputs_pre_hook(inputs: List[torch.Tensor]):
        # Stamp the number of tokens of x into the hooked tensor so that
        # get_valid_tactics can verify the hook was applied
        x, w = inputs
        x_hooked = torch.zeros_like(x)
        x_hooked[-1, 0] = x.shape[0]
        return [x_hooked, w]


def test_autotuner_tuning_configs():
    runner_0 = GemmRunnerComplexTuningConfigs()
    runners = [runner_0]
    x, w = torch.randn(64, 64), torch.randn(64, 128)
    tuning_config = TuningConfig(
        dynamic_tensor_specs=(DynamicTensorSpec(
            input_idx=0,
            dim_idx=0,
            gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
            map_to_tuning_buckets=next_positive_power_of_2,
        ), ),
        # Test if the number of tuning tokens is clipped to 32
        tune_max_num_tokens=GemmRunnerComplexTuningConfigs.tune_max_num_tokens,
        inputs_pre_hook=GemmRunnerComplexTuningConfigs.inputs_pre_hook,
        use_cold_l2_cache=True,
        use_cuda_graph=False,
    )
    temp_dir = tempfile.TemporaryDirectory()
    cache_path = os.path.join(temp_dir.name,
                              "test_autotuner_tactic_configs.json")
    with autotune(cache_path=cache_path):
        tuner = AutoTuner.get()
        runner, best_tactic = tuner.choose_one("test_autotuner_tactic_configs",
                                               runners, tuning_config, [x, w])

    runner_0([x, w], tactic=best_tactic)

    # Test if the tactic can be loaded from cache correctly
    AutoTuner.get().profiling_cache.clear()
    AutoTuner.get().profiling_cache.load_cache(cache_path, rank=0)

    # No further tuning should be performed.
    runner, deserialized_tactic = tuner.choose_one(
        "test_autotuner_tactic_configs", runners, tuning_config, [x, w])
    assert best_tactic == deserialized_tactic, "Tactic should be the same after deserialization"

    runner_0([x, w], tactic=deserialized_tactic)


def test_kernel_testing_single_context():
    """Test kernel testing with a single choose_one context"""
    x, w = torch.randn(16, 64), torch.randn(64, 128)
    runners = [GemmRunner()]
    tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
        input_idx=0,
        dim_idx=0,
        gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
        map_to_tuning_buckets=next_positive_power_of_2), ), )

    tuner = AutoTuner.get()
    tuner.clear_cache()

    # First, do tuning to populate cache
    with autotune():
        runner, tactic = tuner.choose_one("test_kernel_testing_single", runners,
                                          tuning_config, [x, w])

    # Capture execution context
    with tuner.capture() as all_tactics:
        runner, tactic = tuner.choose_one("test_kernel_testing_single", runners,
                                          tuning_config, [x, w])
    reference_output = runner([x, w], tactic=tactic)

    # Test all tactics
    tested_tactics = []
    for (runner, tactic), in all_tactics:
        tested_tactics.append((runner, tactic))
        with tuner.replay(((runner, tactic), )):
            runner_ret, tactic_ret = tuner.choose_one(
                "test_kernel_testing_single", runners, tuning_config, [x, w])
        output = runner_ret([x, w], tactic=tactic_ret)
        # Verify output matches reference
        torch.testing.assert_close(output, reference_output)
        assert runner == runner_ret and tactic == tactic_ret, \
            f"Runner and tactic mismatch: expected ({runner, tactic}), got ({runner_ret, tactic_ret})"

    # Should have tested 3 tactics ([-1, 0, 1])
    assert len(tested_tactics) == len(GemmRunner().get_valid_tactics([x, w], OptimizationProfile([[]]))), \
        f"Expected 3 tactics to be tested, got {len(tested_tactics)}"


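# Emulates an op that calls choose_one more than once per invocation (e.g. the
# two GEMMs of a fused MoE layer); gemm_idx selects which "context" is being
# tuned and changes both the tactic set and the computation.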
class MultiContextRunner(TunableRunner):

    def get_valid_tactics(self, inputs: List[FakeTensor],
                          profile: OptimizationProfile, **kwargs) -> List[int]:
        gemm_idx = kwargs.get("gemm_idx", 0)
        # Different gemm_idx have different numbers of tactics
        if gemm_idx == 0:
            return [0, 1]
        else:
            return [0, 1, 2]

    def forward(self,
                /,
                inputs: List[torch.Tensor],
                *,
                tactic: int = -1,
                **kwargs) -> torch.Tensor:
        gemm_idx = kwargs.get("gemm_idx", 0)
        # Analogous to CUTLASS MoE trtllm::fused_moe FC1
        if gemm_idx == 0:
            return [gemm_0, gemm_1][tactic](inputs[0], inputs[1])
        # Analogous to CUTLASS MoE trtllm::fused_moe FC2
        else:
            return [gemm_0, gemm_1, gemm_fallback][tactic](inputs[1].T,
                                                           inputs[0].T)


def test_kernel_testing_multiple_contexts():
    """
    Test kernel testing with multiple choose_one contexts
    (e.g., CUTLASS MoE trtllm::fused_moe)
    """

    x, w = torch.randn(16, 64), torch.randn(64, 128)
    runners = [MultiContextRunner()]
    tuning_config = TuningConfig()

    tuner = AutoTuner.get()
    tuner.clear_cache()

    # First, do tuning to populate cache
    with autotune():
        runner, _ = tuner.choose_one("test_multi_context",
                                     runners,
                                     tuning_config, [x, w],
                                     gemm_idx=0)
        runner, _ = tuner.choose_one("test_multi_context",
                                     runners,
                                     tuning_config, [x, w],
                                     gemm_idx=1)

    # Capture execution context (captures both choose_one calls)
    with tuner.capture() as all_tactics:
        runner_0, tactic_0 = tuner.choose_one("test_multi_context",
                                              runners,
                                              tuning_config, [x, w],
                                              gemm_idx=0)
        runner_1, tactic_1 = tuner.choose_one("test_multi_context",
                                              runners,
                                              tuning_config, [x, w],
                                              gemm_idx=1)
    ref_output_0 = runner_0([x, w], tactic=tactic_0, gemm_idx=0)
    ref_output_1 = runner_1([x, w], tactic=tactic_1, gemm_idx=1)

    # Test all tactic combinations (cartesian product)
    tested_tactics = []
    for tactic in all_tactics:
        tested_tactics.append(tactic)
        # Each tactic is ((runner_0, tactic_0), (runner_1, tactic_1))
        assert len(tactic) == 2, f"Expected 2 contexts, got {len(tactic)}"

        with tuner.replay(tactic):
            # Make the same calls in the same order
            runner_0, tactic_0 = tuner.choose_one("test_multi_context",
                                                  runners,
                                                  tuning_config, [x, w],
                                                  gemm_idx=0)
            runner_1, tactic_1 = tuner.choose_one("test_multi_context",
                                                  runners,
                                                  tuning_config, [x, w],
                                                  gemm_idx=1)

        output_0 = runner_0([x, w], tactic=tactic_0, gemm_idx=0)
        output_1 = runner_1([x, w], tactic=tactic_1, gemm_idx=1)

        # Verify each context independently
        # Since we're testing different tactics, outputs will differ
        # Just verify they don't crash and have correct shapes
        assert output_0.shape == ref_output_0.shape
        assert output_1.shape == ref_output_1.shape

    # Should have tested 2*3 = 6 combinations
    num_tactics_for_gemm_idx = lambda gemm_idx: len(runners[
        0].get_valid_tactics([x, w], OptimizationProfile(), gemm_idx=gemm_idx))
    assert len(tested_tactics) == num_tactics_for_gemm_idx(0) * num_tactics_for_gemm_idx(1), \
        f"Expected 6 tactic combinations (2*3), got {len(tested_tactics)}"


def test_kernel_testing_mismatched_ops():
    """
    Correctly raise and capture the exception when the captured context does
    not match the operation being performed
    """
    x, w = torch.randn(16, 64), torch.randn(64, 128)
    runners = [GemmRunner()]
    tuning_config = TuningConfig()

    tuner = AutoTuner.get()
    tuner.clear_cache()

    # Capture execution context for operation A
    with tuner.capture() as all_tactics:
        _ = tuner.choose_one("test_op_A", runners, tuning_config, [x, w])

    # Try to test with operation B (should raise RuntimeError)
    try:
        for (runner, tactic), in all_tactics:
            with tuner.replay(((runner, tactic), )):
                # This should raise RuntimeError because custom_op doesn't match
                _ = tuner.choose_one("test_op_B", runners, tuning_config,
                                     [x, w])
        assert False, "Expected RuntimeError for mismatched custom_op, but none was raised"
    except RuntimeError as e:
        # Verify the error message contains useful information
        error_msg = str(e)
        assert "Custom op mismatch" in error_msg, f"Expected 'Custom op mismatch' in error message, got: {error_msg}"
        assert "test_op_A" in error_msg, f"Expected 'test_op_A' in error message, got: {error_msg}"
        assert "test_op_B" in error_msg, f"Expected 'test_op_B' in error message, got: {error_msg}"


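# Runner used by the distributed-strategy tests: prefer_tactics lets each rank
# advertise a different tactic set, and even-numbered tactics are made slower
# by running a few redundant GEMMs before the real one.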
class DistributedGemmRunner(TunableRunner):

    def __init__(self, prefer_tactics: List[int] = [0, 1]):
        self.prefer_tactics = prefer_tactics

    def get_valid_tactics(self, inputs, profile, **kwargs):
        # Return all tactics so merge strategy can choose between them
        return self.prefer_tactics

    def forward(self, inputs, *, tactic=-1, **kwargs):
        # even-numbered tactics are slower
        if tactic % 2 == 0:
            for _ in range(5):
                inputs[0] @ inputs[1]
        return inputs[0] @ inputs[1]

    def unique_id(self):
        return ()


def _distributed_worker_function(world_size, strategy):
    """Worker function to run on each MPI rank."""
    rank = tensorrt_llm.mpi_rank()
    mapping = Mapping(world_size=world_size,
                      rank=rank,
                      tp_size=world_size,
                      pp_size=1)
    if mpi_disabled():
        dist = TorchDist(mapping=mapping)
    else:
        dist = MPIDist(mapping=mapping)

    tuner = AutoTuner.get()
    tuner.clear_cache()
    tuner.setup_distributed_state(mapping, dist)

    x = torch.randn(16, 32, device='cuda')
    w = torch.randn(32, 64, device='cuda')
    inputs = [x, w]

    if strategy == DistributedTuningStrategy.PARALLEL:
        # All ranks get the same set of tactics
        prefer_tactics = [0, 1, 2, 3]
    else:
        # Each rank prefers different tactics
        prefer_tactics = [rank]
    runner = DistributedGemmRunner(prefer_tactics=prefer_tactics)
    config = TuningConfig(distributed_tuning_strategy=strategy)

    if rank == 0:
        temp_dir = tempfile.TemporaryDirectory()
        # rank 0 broadcasts the cache path to all ranks
        cache_path = os.path.join(temp_dir.name, "test_distributed_tuning.json")
        dist.broadcast(cache_path, root=0)
    else:
        cache_path = dist.broadcast(None, root=0)

    with autotune(cache_path=cache_path):
        tuner.choose_one(custom_op=f"test_distributed_{strategy}",
                         runners=[runner],
                         tuning_config=config,
                         inputs=inputs)

    # Check that only one file is created in the cache path
    assert len(os.listdir(os.path.dirname(
        cache_path))) == 1, "Only one rank file should be created"

    # Check cache for distributed tuning
    AutoTuner.get().profiling_cache.clear()
    AutoTuner.get().profiling_cache.load_cache(cache_path, rank)

    selected_runner, best_tactic = tuner.choose_one(
        custom_op=f"test_distributed_{strategy}",
        runners=[runner],
        tuning_config=config,
        inputs=inputs)

    if strategy == DistributedTuningStrategy.BROADCAST:
        # All ranks should select tactic 0
        assert best_tactic == 0
    elif strategy == DistributedTuningStrategy.INDEPENDENT:
        # Each rank should select the tactic it prefers
        assert best_tactic == rank
    elif strategy == DistributedTuningStrategy.MERGE:
        # Because tactic 0 is slower, both ranks should always select tactic 1
        assert best_tactic == 1
    elif strategy == DistributedTuningStrategy.PARALLEL:
        # Tactic 1 or 3 should be selected since they are faster.
        # TODO: This might not cover the case that rank1 tunes nothing
        assert best_tactic % 2 == 1
    else:
        assert False, f"Unknown strategy: {strategy}"

    return True


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Requires at least 2 GPUs for this test")
@pytest.mark.parametrize(
    "strategy",
    [
        DistributedTuningStrategy.BROADCAST,
        DistributedTuningStrategy.INDEPENDENT,
        DistributedTuningStrategy.MERGE,
        DistributedTuningStrategy.PARALLEL,
    ],
)
@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
def test_autotuner_distributed_strategy(strategy, mpi_pool_executor):
    world_size = 2
    # Use MPIPoolExecutor to run distributed test
    results = mpi_pool_executor.map(
        _distributed_worker_function,
        *zip(*[(
            world_size,
            strategy,
        )] * world_size),
    )
    for r in results:
        assert r is True