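# Build standalone TensorRT engines that exercise the LookUpPlugin defined in
# plugin_lib, once per embedding-table dtype (bfloat16, float16, float32).
# Each serialized engine and its builder config are written under tmp/<dtype>/.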
from pathlib import Path

import torch
from plugin_lib import LookUpPlugin

import tensorrt_llm
from tensorrt_llm import Tensor
from tensorrt_llm._utils import torch_dtype_to_str, torch_dtype_to_trt

if __name__ == "__main__":
    # meta data
    batch_size = 10
    vocab_size = 1000
    n_embed = 1024

    # test data
    ## input index
    index_shape = (batch_size, )
    index_data = torch.randint(0, vocab_size, index_shape,
                               dtype=torch.int32).cuda()

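    # Build one engine per dtype: define a tiny network that feeds int32 token
    # indices and an embedding table through LookUpPlugin, then serialize the
    # resulting engine together with its builder config.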
    def test(dtype):
        builder = tensorrt_llm.Builder()
        builder.strongly_typed = True
        network = builder.create_network()
        with tensorrt_llm.net_guard(network):
            # Network inputs: int32 token indices and the embedding table.
            x = Tensor(
                name="x",
                shape=index_shape,
                dtype=tensorrt_llm.str_dtype_to_trt("int32"),
            )
            y = Tensor(name="y",
                       shape=(vocab_size, n_embed),
                       dtype=torch_dtype_to_trt(dtype))

            # Route the inputs through the Python-defined lookup plugin.
            def lookup(x, y):
                lookup_plugin = LookUpPlugin(False, True)
                return lookup_plugin(x, y)

            output = lookup(x, y)

            # The output is marked as float32 regardless of the table dtype.
            output.mark_output("output", torch_dtype_to_str(torch.float32))

        # Build the engine and serialize it, plus its config, to disk.
        builder_config = builder.create_builder_config("float32")
        engine = builder.build_engine(network, builder_config)
        assert engine is not None

        output_dir = Path("tmp") / torch_dtype_to_str(dtype)
        output_dir.mkdir(parents=True, exist_ok=True)

        engine_path = output_dir / "lookup.engine"
        config_path = output_dir / "config.json"

        with engine_path.open("wb") as f:
            f.write(engine)
        builder.save_config(builder_config, str(config_path))

    test(torch.bfloat16)
    test(torch.float16)
    test(torch.float32)
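

# --- Illustrative only: not part of the original build script. ---
# A minimal sketch of how one of the serialized engines above might be loaded
# and executed. It assumes the tensorrt_llm.runtime.Session API
# (from_serialized_engine / run) and that the plugin produces a float32 output
# of shape (batch, n_embed); treat both as assumptions, not facts established
# by this file. The function is defined but never called here.
def run_lookup_engine(engine_path: Path, index: torch.Tensor,
                      weight: torch.Tensor) -> torch.Tensor:
    from tensorrt_llm.runtime import Session

    # Deserialize the engine that test() wrote to disk.
    session = Session.from_serialized_engine(engine_path.read_bytes())

    # Pre-allocate the output buffer; the build above marks it as float32.
    out = torch.empty(index.shape[0], weight.shape[1],
                      dtype=torch.float32, device="cuda")
    stream = torch.cuda.current_stream().cuda_stream
    session.run({"x": index, "y": weight}, {"output": out}, stream)
    torch.cuda.synchronize()
    return out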