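"""Unit tests for the tensor-parallel Embedding and LMHead modules.

Each test spawns one MPI worker per tensor-parallel rank via MPIPoolExecutor
and checks that the sharded module reproduces a plain PyTorch reference
(nn.Embedding / nn.Linear) running on the full, unsharded weights.
"""
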
import pickle
import sys
import traceback

import cloudpickle
import pytest
import torch
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from torch import nn

import tensorrt_llm
from tensorrt_llm._torch.distributed import ParallelConfig, TensorParallelMode
from tensorrt_llm._torch.modules.embedding import Embedding, LMHead
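
# Register this test module by value so cloudpickle can ship the helper
# functions defined here to the MPI worker processes spawned below.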
cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
    cloudpickle.dumps,
    cloudpickle.loads,
    pickle.HIGHEST_PROTOCOL,
)


def run_single_rank(tensor_parallel_size, single_rank_forward_func, input,
                    weights, vocab_size, hidden_size, dtype):
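    """Pin this MPI rank to its GPU and run the given single-rank forward check."""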
    rank = tensorrt_llm.mpi_rank()
    torch.cuda.set_device(rank)
    try:
        single_rank_forward_func(input, vocab_size, hidden_size, dtype,
                                 tensor_parallel_size, rank, weights)
    except Exception:
        traceback.print_exc()
        raise
    return True


@torch.inference_mode
def column_embedding_forward(x, vocab_size, hidden_size, dtype,
                             tensor_parallel_size, tensor_parallel_rank,
                             weight):
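    """COLUMN-parallel (vocab-sharded) Embedding must match nn.Embedding."""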
    x = x.cuda()
    embedding = Embedding(
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.COLUMN),
    )
    embedding.load_weights([dict(weight=weight)])
    embedding.cuda()

    output = embedding.forward(x)

    # torch run
    embedding = nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        dtype=dtype,
    )
    embedding.weight.data.copy_(weight)
    embedding.cuda()

    torch_output = embedding.forward(x)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, torch_output)


@torch.inference_mode
def row_embedding_forward(x, vocab_size, hidden_size, dtype,
                          tensor_parallel_size, tensor_parallel_rank, weight):
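    """ROW-parallel (hidden-sharded) Embedding must match nn.Embedding."""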
    x = x.cuda()
    embedding = Embedding(
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.ROW,
            gather_output=True),
    )
    embedding.load_weights([dict(weight=weight)])
    embedding.cuda()

    output = embedding.forward(x)

    # torch run
    embedding = nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        dtype=dtype,
    )
    embedding.weight.data.copy_(weight)
    embedding.cuda()

    torch_output = embedding.forward(x)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, torch_output)


@torch.inference_mode
def column_lm_head_forward(x, vocab_size, hidden_size, dtype,
                           tensor_parallel_size, tensor_parallel_rank, weight):
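    """COLUMN-parallel (vocab-sharded) LMHead must match nn.Linear."""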
    x = x.cuda()
    lm_head = LMHead(
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            gather_output=True),
    )
    lm_head.load_weights([dict(weight=weight)])
    lm_head.cuda()

    output = lm_head.forward(x)

    # torch run
    lm_head = nn.Linear(
        in_features=hidden_size,
        out_features=vocab_size,
        bias=False,
        dtype=dtype,
    )
    lm_head.weight.data.copy_(weight)
    lm_head.cuda()

    torch_output = lm_head.forward(x)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, torch_output)


@torch.inference_mode
def row_lm_head_forward(x, vocab_size, hidden_size, dtype, tensor_parallel_size,
                        tensor_parallel_rank, weight):
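    """ROW-parallel (hidden-sharded) LMHead on per-rank input slices must
    match nn.Linear on the full input."""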
    x = x.cuda()
    lm_head = LMHead(
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        dtype=dtype,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tensor_parallel_size,
            tensor_parallel_rank=tensor_parallel_rank,
            tensor_parallel_mode=TensorParallelMode.ROW),
    )
    lm_head.load_weights([dict(weight=weight)])
    lm_head.cuda()

    # Row-parallel input: each rank consumes its slice of the hidden dimension.
    xs = torch.chunk(x, tensor_parallel_size, dim=-1)
    output = lm_head.forward(xs[tensor_parallel_rank])

    # torch run
    lm_head = nn.Linear(
        in_features=hidden_size,
        out_features=vocab_size,
        bias=False,
        dtype=dtype,
    )
    lm_head.weight.data.copy_(weight)
    lm_head.cuda()

    torch_output = lm_head.forward(x)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, torch_output, rtol=0.05, atol=0.05)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("vocab_size", [128, 127],
                         ids=["balanced", "unbalanced"])
def test_column_embedding(vocab_size):
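    """Run column_embedding_forward on two MPI ranks, one per GPU."""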
    torch.manual_seed(42)
    seq_len = 10
    hidden_size = 16
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    input = torch.randint(0, vocab_size, (seq_len, ))
    weight = torch.randn((vocab_size, hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, column_embedding_forward, input,
                    weight, vocab_size, hidden_size, dtype)] * 2))
        for r in results:
            assert r is True


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("hidden_size", [16, 15],
                         ids=["balanced", "unbalanced"])
def test_row_embedding(hidden_size):
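    """Run row_embedding_forward on two MPI ranks, one per GPU."""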
    torch.manual_seed(42)
    seq_len = 2
    vocab_size = 128
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    input = torch.randint(0, vocab_size, (seq_len, ))
    weight = torch.randn((vocab_size, hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, row_embedding_forward, input, weight,
                    vocab_size, hidden_size, dtype)] * 2))
        for r in results:
            assert r is True


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("vocab_size", [128, 127],
                         ids=["balanced", "unbalanced"])
def test_column_lm_head(vocab_size):
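    """Run column_lm_head_forward on two MPI ranks, one per GPU."""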
    torch.manual_seed(42)
    seq_len = 10
    hidden_size = 16
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    input = torch.randn((seq_len, hidden_size), dtype=dtype)
    weight = torch.randn((vocab_size, hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, column_lm_head_forward, input, weight,
                    vocab_size, hidden_size, dtype)] * 2))
        for r in results:
            assert r is True


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='needs 2 GPUs to run this test')
@pytest.mark.parametrize("hidden_size", [16, 15],
                         ids=["balanced", "unbalanced"])
def test_row_lm_head(hidden_size):
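    """Run row_lm_head_forward on two MPI ranks, one per GPU."""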
    torch.manual_seed(42)
    seq_len = 2
    vocab_size = 128
    dtype = torch.bfloat16
    tensor_parallel_size = 2
    input = torch.randn((seq_len, hidden_size), dtype=dtype)
    weight = torch.randn((vocab_size, hidden_size), dtype=dtype)
    with MPIPoolExecutor(max_workers=tensor_parallel_size) as executor:
        results = executor.map(
            run_single_rank,
            *zip(*[(tensor_parallel_size, row_lm_head_forward, input, weight,
                    vocab_size, hidden_size, dtype)] * 2))
        for r in results:
            assert r is True
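

# Allow running every parametrization directly as a plain script.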
if __name__ == '__main__':
    test_column_embedding(128)
    test_column_embedding(127)
    test_row_embedding(16)
    test_row_embedding(15)
    test_column_lm_head(128)
    test_column_lm_head(127)
    test_row_lm_head(16)
    test_row_lm_head(15)