import pickle
import sys
import traceback

import cloudpickle
import pytest
import torch
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from torch import nn

import tensorrt_llm
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams
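
# Serialize objects from this test module by value so the MPI worker
# processes spawned by MPIPoolExecutor can unpickle and run these functions
# without importing this file on every rank.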
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
    cloudpickle.dumps,
    cloudpickle.loads,
    pickle.HIGHEST_PROTOCOL,
)
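

# Plain-PyTorch RMSNorm, used as the reference for the fused
# all-reduce + RMSNorm path tested below.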
def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6):
    y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        y = y * weight
    return y
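

# Per-rank entry point: pins the rank to its GPU, runs the given forward
# function, and re-raises any exception so the failure shows up in the
# worker's result.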
def run_single_rank(tensor_parallel_size, single_rank_forward_func, input,
                    weights, hidden_size, dtype):
    rank = tensorrt_llm.mpi_rank()
    torch.cuda.set_device(rank)
    try:
        single_rank_forward_func(input, hidden_size, dtype,
                                 tensor_parallel_size, rank, weights)
    except Exception:
        traceback.print_exc()
        raise
    return True
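

# Two-layer MLP sharded across ranks: a column-parallel up-projection feeds a
# row-parallel down-projection, whose all-reduce restores the full output.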
@torch.inference_mode()
def mlp_forward(x, hidden_size, dtype, tensor_parallel_size,
                tensor_parallel_rank, weights):
    x = x.cuda()

    # Column-parallel up-projection: each rank holds a slice of l0's output
    # features.
    l0 = Linear(
        in_features=hidden_size,
        out_features=4 * hidden_size,
        bias=False,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
        ),
    )
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()
    # Row-parallel down-projection: partial outputs are summed via all-reduce.
    l1 = Linear(in_features=4 * hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.ROW,
                ))
    l1.load_weights([dict(weight=weights[1])])
    l1.cuda()

    mlp = torch.compile(lambda x: l1.forward(l0.forward(x)), fullgraph=True)
    output = mlp(x)

    # Reference run with plain PyTorch on a single GPU.
    l0 = nn.Linear(in_features=hidden_size,
                   out_features=4 * hidden_size,
                   bias=False,
                   dtype=dtype)
    l0.weight.data.copy_(weights[0])
    l0.cuda()

    l1 = nn.Linear(in_features=4 * hidden_size,
                   out_features=hidden_size,
                   bias=False,
                   dtype=dtype)
    l1.weight.data.copy_(weights[1])
    l1.cuda()

    torch_output = l1.forward(l0.forward(x))

    # Compare the tensor-parallel output against the reference.
    torch.cuda.synchronize()
    torch.testing.assert_close(output, torch_output, rtol=0.05, atol=0.05)
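

# Column-parallel linear with gather_output=True: each rank computes a slice
# of the output features, then an all-gather reassembles the full tensor.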
@torch.inference_mode()
def column_linear_forward(x, hidden_size, dtype, tensor_parallel_size,
                          tensor_parallel_rank, weights):

    x = x.cuda()
    l0 = Linear(in_features=hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.COLUMN,
                    gather_output=True))
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()

    l0 = torch.compile(l0, fullgraph=True)
    output = l0.forward(x)

    # Reference run with plain PyTorch on a single GPU.
    l0 = nn.Linear(in_features=hidden_size,
                   out_features=hidden_size,
                   bias=False,
                   dtype=dtype)
    l0.weight.data.copy_(weights[0])
    l0.cuda()

    torch_output = l0.forward(x)

    # Compare the tensor-parallel output against the reference.
    torch.cuda.synchronize()
    torch.testing.assert_close(output, torch_output)
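

# Row-parallel linear: each rank consumes a slice of the input features and
# the partial results are summed with an all-reduce.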
@torch.inference_mode()
def row_linear_forward(x, hidden_size, dtype, tensor_parallel_size,
                       tensor_parallel_rank, weights):

    x = x.cuda()
    l0 = Linear(in_features=hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.ROW,
                ))
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()

    # Each of the two ranks consumes its slice of the feature dimension.
    xs = torch.chunk(x, 2, dim=-1)
    l0 = torch.compile(l0, fullgraph=True)
    output = l0.forward(xs[tensor_parallel_rank])

    # Reference run with plain PyTorch on a single GPU.
    l0 = nn.Linear(in_features=hidden_size,
                   out_features=hidden_size,
                   bias=False,
                   dtype=dtype)
    l0.weight.data.copy_(weights[0])
    l0.cuda()

    torch_output = l0.forward(x)

    # Compare the tensor-parallel output against the reference.
    torch.cuda.synchronize()
    torch.testing.assert_close(output, torch_output, rtol=0.05, atol=0.05)
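

# Row-parallel linear with a fused epilogue: the all-reduce also applies the
# residual add and RMSNorm, returning both the normalized output and the
# pre-norm intermediate tensor.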
@torch.inference_mode()
def row_linear_norm_fusion_forward(x, hidden_size, dtype, tensor_parallel_size,
                                   tensor_parallel_rank, weights):

    x = x.cuda()
    residual = torch.randn_like(x)
    norm_weight = torch.randn((1, hidden_size), dtype=dtype, device="cuda")
    eps = 1e-6
    fusion_params = AllReduceParams(
        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
        residual=residual,
        norm_weight=norm_weight,
        eps=eps,
    )

    l0 = Linear(in_features=hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.ROW,
                ))
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()

    xs = torch.chunk(x, 2, dim=-1)
    l0 = torch.compile(l0, fullgraph=True)
    final_output, inter_output = l0.forward(
        xs[tensor_parallel_rank],
        all_reduce_params=fusion_params,
    )

    # Reference run with plain PyTorch on a single GPU.
    l0 = nn.Linear(in_features=hidden_size,
                   out_features=hidden_size,
                   bias=False,
                   dtype=dtype)
    l0.weight.data.copy_(weights[0])
    l0.cuda()

    torch_output = l0.forward(x)
    torch_inter_output = torch_output + residual
    torch_final_output = rms_norm(torch_inter_output, norm_weight, eps)

    # Compare both fusion outputs against the reference.
    torch.cuda.synchronize()
    torch.testing.assert_close(
        inter_output,
        torch_inter_output,
        rtol=0.05,
        atol=0.15,
    )
    torch.testing.assert_close(
        final_output,
        torch_final_output,
        rtol=0.05,
        atol=0.15,
    )
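

# Each test fans the same argument tuple out to two MPI ranks; a rank returns
# True only if its forward function ran without raising.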
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
def test_mlp():
    torch.manual_seed(42)
    seq_len = 2
    hidden_size = 16
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    x = torch.randn((seq_len, hidden_size), dtype=dtype)
    l0_weight = torch.randn((4 * hidden_size, hidden_size), dtype=dtype)
    l1_weight = torch.randn((hidden_size, 4 * hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, mlp_forward, x,
                    [l0_weight, l1_weight], hidden_size, dtype)] * 2))
        for r in results:
            assert r is True
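

# The "unbalanced" case uses an odd hidden size that cannot be split evenly
# across two ranks and is expected to fail.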
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("hidden_size", [128, 127],
                         ids=["balanced", "unbalanced"])
def test_column_linear(hidden_size):
    torch.manual_seed(42)
    seq_len = 10
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    x = torch.randn((seq_len, hidden_size), dtype=dtype)
    l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, column_linear_forward, x, [l0_weight],
                    hidden_size, dtype)] * 2))
        if hidden_size % 2 != 0:
            with pytest.raises(AssertionError):
                for r in results:
                    assert r is True
        else:
            for r in results:
                assert r is True


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("hidden_size", [16, 15],
                         ids=["balanced", "unbalanced"])
def test_row_linear(hidden_size):
    torch.manual_seed(42)
    seq_len = 2
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    x = torch.randn((seq_len, hidden_size), dtype=dtype)
    l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, row_linear_forward, x, [l0_weight],
                    hidden_size, dtype)] * 2))
        if hidden_size % 2 != 0:
            with pytest.raises(AssertionError):
                for r in results:
                    assert r is True
        else:
            for r in results:
                assert r is True
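

# The fusion test sweeps sequence length and hidden size to exercise the
# fused all-reduce + RMSNorm kernel on several shapes.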
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("seq_len", [2, 32], ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("hidden_size", [16, 256], ids=lambda x: f"hidden:{x}")
def test_row_linear_norm_fusion(seq_len, hidden_size):
    torch.manual_seed(42)
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    x = torch.randn((seq_len, hidden_size), dtype=dtype)
    l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, row_linear_norm_fusion_forward, x,
                    [l0_weight], hidden_size, dtype)] * 2))
        for r in results:
            assert r is True
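

# Allows running the checks directly with `python`, outside the pytest runner.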
if __name__ == '__main__':
    test_column_linear(128)
    test_column_linear(127)
    test_row_linear(16)
    test_row_linear(15)
    test_mlp()
    test_row_linear_norm_fusion(32, 256)
    test_row_linear_norm_fusion(32, 16)
    test_row_linear_norm_fusion(2, 16)
    test_row_linear_norm_fusion(2, 256)