# TensorRT-LLM/tests/_torch/multi_gpu/test_linear.py
import pickle
import sys
import traceback

import cloudpickle
import pytest
import torch
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from torch import nn

import tensorrt_llm
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams

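# Serialize this test module by value so the MPI worker processes spawned by
# MPIPoolExecutor can unpickle the functions defined here.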
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
    cloudpickle.dumps,
    cloudpickle.loads,
    pickle.HIGHEST_PROTOCOL,
)


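# Reference RMSNorm used to check the fused allreduce + residual + norm path.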
def rms_norm(x: torch.Tensor, weight: torch.Tensor = None, eps: float = 1e-6):
    y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if weight is not None:
        y = y * weight
    return y


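# Entry point executed on every MPI rank: bind the rank to a GPU, run the given
# single-rank forward function, and surface any exception raised on that rank.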
def run_single_rank(tensor_parallel_size, single_rank_forward_func, input,
                    weights, hidden_size, dtype):
    rank = tensorrt_llm.mpi_rank()
    torch.cuda.set_device(rank)
    try:
        single_rank_forward_func(input, hidden_size, dtype,
                                 tensor_parallel_size, rank, weights)
    except Exception:
        traceback.print_exc()
        raise
    return True


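# Per-rank forward for a two-layer MLP: a column-parallel up-projection followed
# by a row-parallel down-projection, compiled with torch.compile and compared
# against a plain nn.Linear reference using the full, unsharded weights.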
@torch.inference_mode
def mlp_forward(x, hidden_size, dtype, tensor_parallel_size,
                tensor_parallel_rank, weights):
    x = x.cuda()
    l0 = Linear(
        in_features=hidden_size,
        out_features=4 * hidden_size,
        bias=False,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
        ),
    )
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()
    l1 = Linear(in_features=4 * hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.ROW,
                ))
    l1.load_weights([dict(weight=weights[1])])
    l1.cuda()
    mlp = torch.compile(lambda x: l1.forward(l0.forward(x)), fullgraph=True)
    output = mlp(x)

    # torch run
    l0 = nn.Linear(in_features=hidden_size,
                   out_features=4 * hidden_size,
                   bias=False,
                   dtype=dtype)
    l0.weight.data.copy_(weights[0])
    l0.cuda()
    l1 = nn.Linear(in_features=4 * hidden_size,
                   out_features=hidden_size,
                   bias=False,
                   dtype=dtype)
    l1.weight.data.copy_(weights[1])
    l1.cuda()
    torch_output = l1.forward(l0.forward(x))

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, torch_output, rtol=0.05, atol=0.05)


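# Per-rank forward for a column-parallel Linear with gather_output=True, so the
# gathered full output can be compared directly against an unsharded nn.Linear.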
@torch.inference_mode
def column_linear_forward(x, hidden_size, dtype, tensor_parallel_size,
                          tensor_parallel_rank, weights):
    x = x.cuda()
    l0 = Linear(in_features=hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.COLUMN,
                    gather_output=True))
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()
    l0 = torch.compile(l0, fullgraph=True)
    output = l0.forward(x)

    # torch run
    l0 = nn.Linear(in_features=hidden_size,
                   out_features=hidden_size,
                   bias=False,
                   dtype=dtype)
    l0.weight.data.copy_(weights[0])
    l0.cuda()
    torch_output = l0.forward(x)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, torch_output)


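# Per-rank forward for a row-parallel Linear: each rank consumes its chunk of
# the input, and the partial outputs are summed across ranks by an allreduce.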
@torch.inference_mode
def row_linear_forward(x, hidden_size, dtype, tensor_parallel_size,
                       tensor_parallel_rank, weights):
    x = x.cuda()
    l0 = Linear(in_features=hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.ROW,
                ))
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()
    xs = torch.chunk(x, 2, dim=-1)
    l0 = torch.compile(l0, fullgraph=True)
    output = l0.forward(xs[tensor_parallel_rank])

    # torch run
    l0 = nn.Linear(in_features=hidden_size,
                   out_features=hidden_size,
                   bias=False,
                   dtype=dtype)
    l0.weight.data.copy_(weights[0])
    l0.cuda()
    torch_output = l0.forward(x)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, torch_output, rtol=0.05, atol=0.05)


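# Per-rank forward for a row-parallel Linear using the RESIDUAL_RMS_NORM
# allreduce fusion, checked against an unfused residual-add + rms_norm reference.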
@torch.inference_mode
def row_linear_norm_fusion_forward(x, hidden_size, dtype, tensor_parallel_size,
                                   tensor_parallel_rank, weights):
    x = x.cuda()
    residual = torch.randn_like(x)
    norm_weight = torch.randn((1, hidden_size), dtype=dtype, device="cuda")
    eps = 1e-6
    fusion_params = AllReduceParams(
        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
        residual=residual,
        norm_weight=norm_weight,
        eps=eps,
    )
    l0 = Linear(in_features=hidden_size,
                out_features=hidden_size,
                bias=False,
                dtype=dtype,
                parallel_config=ParallelConfig(
                    tensor_parallel_size=tensor_parallel_size,
                    tensor_parallel_rank=tensor_parallel_rank,
                    tensor_parallel_mode=TensorParallelMode.ROW,
                ))
    l0.load_weights([dict(weight=weights[0])])
    l0.cuda()
    xs = torch.chunk(x, 2, dim=-1)
    l0 = torch.compile(l0, fullgraph=True)
    final_output, inter_output = l0.forward(
        xs[tensor_parallel_rank],
        all_reduce_params=fusion_params,
    )

    # torch run
    l0 = nn.Linear(in_features=hidden_size,
                   out_features=hidden_size,
                   bias=False,
                   dtype=dtype)
    l0.weight.data.copy_(weights[0])
    l0.cuda()
    torch_output = l0.forward(x)
    torch_inter_output = torch_output + residual
    torch_final_output = rms_norm(torch_inter_output, norm_weight, eps)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(
        inter_output,
        torch_inter_output,
        rtol=0.05,
        atol=0.15,
    )
    torch.testing.assert_close(
        final_output,
        torch_final_output,
        rtol=0.05,
        atol=0.15,
    )


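# Run mlp_forward on two MPI ranks and check that every rank reports success.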
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
def test_mlp():
    torch.manual_seed(42)
    seq_len = 2
    hidden_size = 16
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    x = torch.randn((seq_len, hidden_size), dtype=dtype)
    l0_weight = torch.randn((4 * hidden_size, hidden_size), dtype=dtype)
    l1_weight = torch.randn((hidden_size, 4 * hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, mlp_forward, x,
                    [l0_weight, l1_weight], hidden_size, dtype)] * 2))
        for r in results:
            assert r is True


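# Column-parallel Linear on two ranks; the "unbalanced" case (odd hidden_size)
# cannot be split evenly across ranks and is expected to fail.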
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("hidden_size", [128, 127],
                         ids=["balanced", "unbalanced"])
def test_column_linear(hidden_size):
    torch.manual_seed(42)
    seq_len = 10
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    x = torch.randn((seq_len, hidden_size), dtype=dtype)
    l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, column_linear_forward, x, [l0_weight],
                    hidden_size, dtype)] * 2))
        if hidden_size % 2 != 0:
            with pytest.raises(AssertionError):
                for r in results:
                    assert r is True
        else:
            for r in results:
                assert r is True


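# Row-parallel Linear on two ranks; as above, an odd hidden_size cannot be
# split evenly across ranks and is expected to fail.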
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("hidden_size", [16, 15],
                         ids=["balanced", "unbalanced"])
def test_row_linear(hidden_size):
    torch.manual_seed(42)
    seq_len = 2
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    x = torch.randn((seq_len, hidden_size), dtype=dtype)
    l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, row_linear_forward, x, [l0_weight],
                    hidden_size, dtype)] * 2))
        if hidden_size % 2 != 0:
            with pytest.raises(AssertionError):
                for r in results:
                    assert r is True
        else:
            for r in results:
                assert r is True


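# Row-parallel Linear with the fused allreduce + residual + RMSNorm path on two
# ranks, across several sequence lengths and hidden sizes.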
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("seq_len", [2, 32], ids=lambda x: f"seqlen:{x}")
@pytest.mark.parametrize("hidden_size", [16, 256], ids=lambda x: f"hidden:{x}")
def test_row_linear_norm_fusion(seq_len, hidden_size):
    torch.manual_seed(42)
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    x = torch.randn((seq_len, hidden_size), dtype=dtype)
    l0_weight = torch.randn((hidden_size, hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, row_linear_norm_fusion_forward, x,
                    [l0_weight], hidden_size, dtype)] * 2))
        for r in results:
            assert r is True


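# Allow running the tests directly as a script.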
if __name__ == '__main__':
    test_column_linear(128)
    test_column_linear(127)
    test_row_linear(16)
    test_row_linear(15)
    test_mlp()
    test_row_linear_norm_fusion(32, 256)
    test_row_linear_norm_fusion(32, 16)
    test_row_linear_norm_fusion(2, 16)
    test_row_linear_norm_fusion(2, 256)