# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from itertools import product
from typing import Tuple

import numpy as np
import pytest

# isort: off
import torch
import tensorrt as trt
# isort: on
from parameterized import parameterized
from polygraphy.backend.trt import (CreateConfig, EngineFromNetwork, Profile,
                                    TrtRunner)
from transformers.models.bloom.modeling_bloom import build_alibi_tensor
from transformers.models.llama.modeling_llama import (LlamaConfig, LlamaMLP,
                                                      LlamaRMSNorm)
from utils.torch_ref import (attention_qkvpacked_ref, group_rms_norm_ref,
                             mamba2_ref, mamba_ref, recurrent_ref)
from utils.util import skip_fp8_pre_ada, unittest_name_func

import tensorrt_llm
from tensorrt_llm import Tensor
from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy
from tensorrt_llm.layers import (AttentionParams, KeyValueCacheParams,
                                 PositionEmbeddingType)
from tensorrt_llm.quantization import QuantMode


class TestLayer(unittest.TestCase):
|
|
|
|
def setUp(self):
|
|
tensorrt_llm.logger.set_level('error')
|
|
|
|
def test_group_norm_float32(self):
|
|
# test data
|
|
dtype = 'float32'
|
|
x_data = torch.randn(2, 6, 3, 3)
|
|
m = torch.nn.GroupNorm(3, 6)
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
gm = tensorrt_llm.layers.GroupNorm(3, 6)
|
|
|
|
gm.weight.value = m.weight.detach().cpu().numpy()
|
|
gm.bias.value = m.bias.detach().cpu().numpy()
|
|
output = gm.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data.numpy()})
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
ref = m(x_data)
|
|
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.cpu().numpy(),
|
|
outputs['output'],
|
|
atol=1e-6)
|
|
|
|
def test_layer_norm_float32(self):
|
|
# test data
|
|
dtype = 'float32'
|
|
x_data = torch.randn(2, 5, 10, 10)
|
|
m = torch.nn.LayerNorm([5, 10, 10])
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
gm = tensorrt_llm.layers.LayerNorm([5, 10, 10])
|
|
|
|
gm.weight.value = m.weight.detach().cpu().numpy()
|
|
gm.bias.value = m.bias.detach().cpu().numpy()
|
|
output = gm.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data.numpy()})
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
ref = m(x_data)
|
|
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.cpu().numpy(),
|
|
outputs['output'],
|
|
atol=1e-6)
|
|
|
|
def test_rms_norm_float32(self):
|
|
# test data
|
|
test_shape = [2, 5, 10, 16]
|
|
dtype = 'float32'
|
|
x_data = torch.randn(*test_shape)
|
|
m = LlamaRMSNorm(test_shape[-1]) # LlamaRMSNorm only supports last dim
|
|
with torch.no_grad():
|
|
m.weight.copy_(torch.rand([test_shape[-1]]))
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
gm = tensorrt_llm.layers.RmsNorm(test_shape[-1])
|
|
|
|
gm.weight.value = m.weight.detach().cpu().numpy()
|
|
output = gm.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data.numpy()})
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
ref = m(x_data)
|
|
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.cpu().numpy(),
|
|
outputs['output'],
|
|
atol=1e-6)
|
|
|
|
def test_group_rms_norm_float32(self):
|
|
# test data
|
|
test_shape = [2, 5, 10, 16]
|
|
num_groups = 4
|
|
dtype = 'float32'
|
|
device = 'cuda'
|
|
torch_dtype = str_dtype_to_torch(dtype)
|
|
x_data = torch.randn(*test_shape, dtype=torch_dtype, device=device)
|
|
weight_data = torch.randn(test_shape[-1],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
gm = tensorrt_llm.layers.RmsNorm(test_shape[-1], num_groups=4)
|
|
|
|
gm.weight.value = weight_data.detach().cpu().numpy()
|
|
output = gm.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data.cpu().numpy()})
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
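            # The reference normalizes each group of hidden_size // num_groups channels independently.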
ref = group_rms_norm_ref(x_data,
|
|
weight_data,
|
|
group_size=test_shape[-1] // num_groups)
|
|
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.cpu().numpy(),
|
|
outputs['output'],
|
|
atol=1e-6)
|
|
|
|
@parameterized.expand([[tensorrt_llm.layers.GatedMLP, 'float32'],
|
|
[tensorrt_llm.layers.GatedMLP, 'fp8'],
|
|
[tensorrt_llm.layers.FusedGatedMLP, 'float32']],
|
|
name_func=unittest_name_func)
|
|
def test_gated_mlp(self, ClsMLP, qformat):
|
|
|
|
skip_fp8_pre_ada(qformat == 'fp8')
|
|
if qformat == 'fp8':
|
|
pytest.xfail("FIXME: test is broken since 0a1990b69")
|
|
|
|
# test data
|
|
d_h = 8
|
|
ffn_h = 20
|
|
test_shape = [2, 3, 5, d_h]
|
|
dtype = 'float32'
|
|
torch.random.manual_seed(0)
|
|
# need rand for 'normalized' values
|
|
x_data = torch.randn(*test_shape)
|
|
fc = torch.empty(ffn_h, d_h)
|
|
torch.nn.init.xavier_uniform_(fc)
|
|
gate = torch.empty(ffn_h, d_h)
|
|
torch.nn.init.xavier_uniform_(gate)
|
|
proj = torch.empty(d_h, ffn_h)
|
|
torch.nn.init.xavier_uniform_(proj)
|
|
config = LlamaConfig(hidden_size=d_h,
|
|
intermediate_size=ffn_h,
|
|
hidden_act='silu')
|
|
m = LlamaMLP(config)
|
|
        # Need torch.no_grad() to update the weights of the torch.nn.Linear modules in place
|
|
with torch.no_grad():
|
|
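            # Note the naming mismatch: TRT-LLM's fc/gate correspond to HF LlamaMLP's gate_proj/up_proj respectively.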
m.gate_proj.weight.copy_(fc)
|
|
m.up_proj.weight.copy_(gate)
|
|
m.down_proj.weight.copy_(proj)
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
        builder.strongly_typed = False  # Test needs to run in weakly typed mode
|
|
net = builder.create_network()
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
quant_mode = QuantMode(0)
|
|
if qformat == 'fp8':
|
|
quant_mode = quant_mode.set_fp8_qdq()
|
|
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
gm = ClsMLP(d_h,
|
|
ffn_h,
|
|
hidden_act='silu',
|
|
bias=False,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype),
|
|
quant_mode=quant_mode)
|
|
|
|
            # TensorRT-LLM's Linear uses the Parameter class, which has a 'value' setter
|
|
if isinstance(gm, tensorrt_llm.layers.FusedGatedMLP):
|
|
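                # FusedGatedMLP fuses the gate and fc projections into a single GEMM; the packed weight order is [gate, fc].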
fused_fc = torch.cat([gate, fc], dim=0).cpu().numpy()
|
|
gm.fused_fc.weight.value = fused_fc
|
|
else:
|
|
gm.fc.weight.value = fc.cpu().numpy()
|
|
gm.gate.weight.value = gate.cpu().numpy()
|
|
gm.proj.weight.value = proj.cpu().numpy()
|
|
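            # Arbitrary per-tensor FP8 scaling factors; their exact values only matter for staying within the relaxed fp8 tolerances used below.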
if quant_mode.has_fp8_qdq():
|
|
gm.proj.weights_scaling_factor.value = np.array(
|
|
[0.42], dtype=np.float32)
|
|
gm.proj.activation_scaling_factor.value = np.array(
|
|
[0.42], dtype=np.float32)
|
|
if isinstance(gm, tensorrt_llm.layers.FusedGatedMLP):
|
|
gm.fused_fc.weights_scaling_factor.value = np.array(
|
|
[0.42], dtype=np.float32)
|
|
gm.fused_fc.activation_scaling_factor.value = np.array(
|
|
[0.42], dtype=np.float32)
|
|
else:
|
|
gm.fc.weights_scaling_factor.value = np.array(
|
|
[1.42], dtype=np.float32)
|
|
gm.gate.weights_scaling_factor.value = np.array(
|
|
[1.42], dtype=np.float32)
|
|
gm.fc.activation_scaling_factor.value = np.array(
|
|
[0.42], dtype=np.float32)
|
|
gm.gate.activation_scaling_factor.value = np.array(
|
|
[0.42], dtype=np.float32)
|
|
|
|
output = gm.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork(
|
|
(builder.trt_builder, net.trt_network),
|
|
CreateConfig(fp8=(qformat == 'fp8'), precision_constraints="obey"))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data.numpy()})
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
ref = m(x_data)
|
|
|
|
# compare diff
|
|
kwargs = {
|
|
'atol': 0.2,
|
|
'rtol': 0.03
|
|
} if qformat == 'fp8' else {
|
|
'atol': 1e-5
|
|
}
|
|
np.testing.assert_allclose(ref.cpu().numpy(), outputs['output'],
|
|
**kwargs)
|
|
|
|
@parameterized.expand(
|
|
[["float32", False], ["float32", True], ["float16", False],
|
|
["float16", True], ["float16", True, "float32"], ["bfloat16", False],
|
|
["bfloat16", True], ["bfloat16", True, "float32"],
|
|
["float32", True, None, 4], ["float16", True, None, 4],
|
|
["bfloat16", True, None, 4], ["float16", True, "float32", 4],
|
|
["bfloat16", True, "float32", 4], ["float32", True, None, 0, 10],
|
|
["float16", True, None, 0, 10], ["bfloat16", True, None, 0, 10],
|
|
["float16", True, "float32", 0, 10],
|
|
["bfloat16", True, "float32", 0, 10], ["float32", True, None, 4, 10],
|
|
["float16", True, None, 4, 10], ["bfloat16", True, None, 4, 10],
|
|
["float16", True, "float32", 4, 10],
|
|
["bfloat16", True, "float32", 4, 10]],
|
|
name_func=unittest_name_func)
|
|
def test_linear(self,
|
|
dtype,
|
|
use_plugin,
|
|
output_dtype=None,
|
|
pad_lda=0,
|
|
pad_ldc=0):
|
|
if output_dtype is None:
|
|
output_dtype = dtype
|
|
|
|
# test data
|
|
torch.manual_seed(0)
|
|
torch_dtype = str_dtype_to_torch(dtype)
|
|
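        # pad_lda pads the input's leading dimension; pad_ldc pads the output, so the bias is disabled and the padded columns are stripped after inference.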
x_data = torch.randn(128, 20 + pad_lda, dtype=torch_dtype)
|
|
m = torch.nn.Linear(20, 30, bias=(pad_ldc == 0), dtype=torch.float32)
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
        builder.strongly_typed = False  # Test needs to run in weakly typed mode
|
|
net = builder.create_network()
|
|
if use_plugin:
|
|
net.plugin_config.gemm_plugin = dtype
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
gm = tensorrt_llm.layers.Linear(20,
|
|
30,
|
|
bias=(pad_ldc == 0),
|
|
dtype=dtype,
|
|
pad_lda=pad_lda,
|
|
pad_ldc=pad_ldc)
|
|
|
|
gm.weight.value = torch_to_numpy(
|
|
m.weight.to(torch_dtype).detach().cpu())
|
|
if pad_ldc == 0:
|
|
gm.bias.value = torch_to_numpy(
|
|
m.bias.to(torch_dtype).detach().cpu())
|
|
output = gm.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
output.dtype = tensorrt_llm.str_dtype_to_trt(output_dtype)
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork(
|
|
(builder.trt_builder, net.trt_network),
|
|
CreateConfig(fp16=dtype == "float16",
|
|
bf16=dtype == "bfloat16",
|
|
precision_constraints="obey"))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data})
|
|
|
|
if pad_ldc:
|
|
outputs['output'] = torch.split(outputs['output'], [30, pad_ldc],
|
|
dim=-1)[0]
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
ref = m(x_data[:, 0:20].to(torch.float32)).to(
|
|
str_dtype_to_torch(output_dtype))
|
|
|
|
# The absolute tolerance for bfloat16 is increased marginally because
|
|
        # a single value (out of 4000) breaks tolerance on a 4090 under Linux/Windows.
|
|
atols = {"float32": 1e-6, "float16": 1e-2, "bfloat16": 1.6e-2}
|
|
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(),
|
|
outputs['output'].to(torch.float32).numpy(),
|
|
atol=atols[dtype])
|
|
|
|
@parameterized.expand([
|
|
["float32", False],
|
|
["float32", True],
|
|
["float16", False],
|
|
["float16", True],
|
|
["float16", True, "float32"],
|
|
["bfloat16", False],
|
|
["bfloat16", True],
|
|
["bfloat16", True, "float32"],
|
|
],
|
|
name_func=unittest_name_func)
|
|
def test_grouped_linear(self, dtype, use_plugin, output_dtype=None):
|
|
if output_dtype is None:
|
|
output_dtype = dtype
|
|
|
|
# test data
|
|
batch = 128
|
|
in_features = 20
|
|
out_features = 30
|
|
num_blocks = 5
|
|
torch.manual_seed(0)
|
|
torch_dtype = str_dtype_to_torch(dtype)
|
|
x_data = torch.randn(batch, in_features, dtype=torch_dtype)
|
|
|
|
linear_weight = torch.randn(
|
|
[num_blocks, in_features // num_blocks, out_features // num_blocks],
|
|
dtype=str_dtype_to_torch(dtype))
|
|
linear_bias = torch.randn([num_blocks, out_features // num_blocks],
|
|
dtype=str_dtype_to_torch(dtype))
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
        builder.strongly_typed = False  # Test needs to run in weakly typed mode
|
|
net = builder.create_network()
|
|
if use_plugin:
|
|
net.plugin_config.gemm_plugin = dtype
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
gm = tensorrt_llm.layers.GroupedLinear(in_features,
|
|
out_features,
|
|
num_blocks,
|
|
dtype=dtype)
|
|
gm.weight.value = torch_to_numpy(
|
|
linear_weight.to(torch_dtype).detach().cpu())
|
|
gm.bias.value = torch_to_numpy(
|
|
linear_bias.to(torch_dtype).detach().cpu())
|
|
output = gm.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
output.dtype = tensorrt_llm.str_dtype_to_trt(output_dtype)
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork(
|
|
(builder.trt_builder, net.trt_network),
|
|
CreateConfig(fp16=dtype == "float16",
|
|
bf16=dtype == "bfloat16",
|
|
precision_constraints="obey"))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data})
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
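            # Reference: apply each block's weight to its slice of the input, then flatten the per-block outputs.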
x = x_data.view(batch, num_blocks, in_features // num_blocks)
|
|
y = torch.einsum("... h i, h i j -> ... h j", x,
|
|
linear_weight) + linear_bias
|
|
ref = y.reshape(batch, out_features)
|
|
|
|
# The absolute tolerance for bfloat16 is increased marginally because
|
|
        # a single value (out of 4000) breaks tolerance on a 4090 under Linux/Windows.
|
|
atols = {"float32": 1e-6, "float16": 1e-2, "bfloat16": 1.6e-2}
|
|
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(),
|
|
outputs['output'].to(torch.float32).numpy(),
|
|
atol=atols[dtype])
|
|
|
|
@parameterized.expand(list(product([True, False])),
|
|
name_func=unittest_name_func)
|
|
def test_prompt_tuning_embedding(self, remove_padding):
|
|
torch.random.manual_seed(0)
|
|
dtype = "bfloat16"
|
|
trt_dtype = tensorrt_llm.str_dtype_to_trt(dtype)
|
|
torch_dtype = str_dtype_to_torch(dtype)
|
|
embedding_dim = 64
|
|
batch_size = 8
|
|
seq_len = 12
|
|
vocab_size = 100
|
|
num_embeddings = 128
|
|
num_tasks = 3
|
|
task_vocab_size = 30
|
|
|
|
embeddings = torch.randn((num_embeddings, embedding_dim),
|
|
dtype=torch_dtype)
|
|
prompt_embedding = torch.randn(
|
|
(num_tasks * task_vocab_size, embedding_dim), dtype=torch_dtype)
|
|
ids = torch.randint(0,
|
|
vocab_size, (batch_size, seq_len),
|
|
dtype=torch.int32)
|
|
request_tasks = torch.randint(0,
|
|
num_tasks, (batch_size, ),
|
|
dtype=torch.int32)
|
|
request_tasks = request_tasks.unsqueeze(-1).expand(*ids.shape)
|
|
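        # v_ids are virtual prompt-token ids (>= vocab_size); roughly half of the positions are replaced with them via the Bernoulli mask below.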
v_ids = torch.randint(vocab_size,
|
|
vocab_size + task_vocab_size,
|
|
(batch_size, seq_len),
|
|
dtype=torch.int32)
|
|
mask = torch.bernoulli(torch.full((batch_size, seq_len),
|
|
0.5)).to(torch.int32)
|
|
ids = ids * mask + v_ids * (1 - mask)
|
|
|
|
if remove_padding:
|
|
input_ids = ids.flatten().unsqueeze(0)
|
|
request_tasks = request_tasks.flatten().unsqueeze(0)
|
|
else:
|
|
input_ids = ids
|
|
|
|
builder = tensorrt_llm.Builder()
|
|
        builder.strongly_typed = False  # Test needs to run in weakly typed mode
|
|
net = builder.create_network()
|
|
|
|
with tensorrt_llm.net_guard(net):
|
|
ids_tensor = Tensor(name='ids',
|
|
shape=[1, -1] if remove_padding else [-1, -1],
|
|
dtype=trt.int32)
|
|
prompt_embedding_tensor = Tensor(name='prompt_embedding',
|
|
shape=[-1, embedding_dim],
|
|
dtype=trt_dtype)
|
|
request_tasks_tensor = Tensor(name='request_tasks',
|
|
shape=[-1, -1],
|
|
dtype=trt.int32)
|
|
task_vocab_size_tensor = Tensor(name='task_vocab_size',
|
|
shape=(1, ),
|
|
dtype=trt.int32)
|
|
|
|
embedding = tensorrt_llm.layers.PromptTuningEmbedding(
|
|
num_embeddings, embedding_dim, vocab_size, trt_dtype)
|
|
embedding.weight.value = torch_to_numpy(embeddings.detach().cpu())
|
|
|
|
output = embedding(ids_tensor, prompt_embedding_tensor,
|
|
request_tasks_tensor, task_vocab_size_tensor)
|
|
net._mark_output(output, "output", dtype=trt_dtype)
|
|
|
|
profile = (Profile().add(
|
|
"ids", (1, 1), input_ids.shape, input_ids.shape).add(
|
|
"prompt_embedding", (1, embedding_dim), prompt_embedding.shape,
|
|
prompt_embedding.shape).add("request_tasks", (1, 1),
|
|
input_ids.shape, input_ids.shape))
|
|
|
|
build_engine = EngineFromNetwork(
|
|
(builder.trt_builder, net.trt_network),
|
|
config=CreateConfig(bf16=(dtype == "bfloat16"),
|
|
fp16=(dtype == "float16"),
|
|
precision_constraints="obey",
|
|
profiles=[profile]))
|
|
assert build_engine is not None
|
|
with TrtRunner(build_engine) as runner:
|
|
output = runner.infer(
|
|
feed_dict={
|
|
'ids':
|
|
input_ids,
|
|
'prompt_embedding':
|
|
prompt_embedding,
|
|
'request_tasks':
|
|
request_tasks,
|
|
'task_vocab_size':
|
|
torch.tensor([task_vocab_size], dtype=torch.int32),
|
|
})['output']
|
|
|
|
output = output.to(torch.float32)
|
|
embeddings = embeddings.to(torch.float32)
|
|
prompt_embedding = prompt_embedding.view(
|
|
(num_tasks, task_vocab_size, embedding_dim)).to(torch.float32)
|
|
|
|
# use loops for clarity, even if it's non-optimal
|
|
for b in range(input_ids.shape[0]):
|
|
for s in range(input_ids.shape[1]):
|
|
token_id = input_ids[b][s]
|
|
if token_id < vocab_size:
|
|
np.testing.assert_allclose(output[b][s],
|
|
embeddings[token_id])
|
|
else:
|
|
offset_token_id = token_id - vocab_size
|
|
task = request_tasks[b][s]
|
|
np.testing.assert_allclose(
|
|
output[b][s], prompt_embedding[task][offset_token_id])
|
|
|
|
def test_conv2d_float32(self):
|
|
# test data
|
|
dtype = 'float32'
|
|
x_data = torch.randn(20, 16, 50, 100)
|
|
m = torch.nn.Conv2d(16,
|
|
33, (3, 5),
|
|
stride=(2, 1),
|
|
padding=(4, 2),
|
|
dilation=(3, 1))
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
gm = tensorrt_llm.layers.Conv2d(16,
|
|
33, (3, 5),
|
|
stride=(2, 1),
|
|
padding=(4, 2),
|
|
dilation=(3, 1))
|
|
|
|
gm.weight.value = m.weight.detach().cpu().numpy()
|
|
gm.bias.value = m.bias.detach().cpu().numpy()
|
|
output = gm.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data.numpy()})
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
ref = m(x_data)
|
|
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.cpu().numpy(),
|
|
outputs['output'],
|
|
atol=1e-5)
|
|
|
|
def test_conv_transpose2d_float32(self):
|
|
# test data
|
|
dtype = 'float32'
|
|
x_data = torch.randn(20, 16, 50, 100)
|
|
m = torch.nn.ConvTranspose2d(16,
|
|
33, (3, 5),
|
|
stride=(2, 1),
|
|
padding=(4, 2))
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
gm = tensorrt_llm.layers.ConvTranspose2d(16,
|
|
33, (3, 5),
|
|
stride=(2, 1),
|
|
padding=(4, 2),
|
|
dilation=(3, 1))
|
|
|
|
gm.weight.value = m.weight.detach().cpu().numpy()
|
|
gm.bias.value = m.bias.detach().cpu().numpy()
|
|
output = gm.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data.numpy()})
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
ref = m(x_data)
|
|
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.cpu().numpy(),
|
|
outputs['output'],
|
|
atol=1e-05)
|
|
|
|
def test_avg_pooling_2d_float32(self):
|
|
# test data
|
|
dtype = 'float32'
|
|
x_data = torch.randn(2, 16, 50, 32)
|
|
m = torch.nn.AvgPool2d((3, 2), stride=(2, 1))
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
ap2d = tensorrt_llm.layers.AvgPool2d((3, 2), stride=(2, 1))
|
|
output = ap2d.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data.numpy()})
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
ref = m(x_data)
|
|
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.cpu().numpy(),
|
|
outputs['output'],
|
|
atol=1e-6)
|
|
|
|
@parameterized.expand([("bfloat16", "float32"), ("float32", "bfloat16")],
|
|
name_func=unittest_name_func)
|
|
def test_cast_bf16(self, from_dtype, to_dtype):
|
|
|
|
torch_from_dtype = str_dtype_to_torch(from_dtype)
|
|
torch_to_dtype = str_dtype_to_torch(to_dtype)
|
|
x_data = torch.randn(2, 2, 3, 6, dtype=torch_from_dtype)
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
        builder.strongly_typed = False  # Test needs to run in weakly typed mode
|
|
net = builder.create_network()
|
|
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(from_dtype))
|
|
|
|
cast = tensorrt_llm.layers.Cast(to_dtype)
|
|
output = cast.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork(
|
|
(builder.trt_builder, net.trt_network),
|
|
config=CreateConfig(bf16=True, precision_constraints="obey"))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data})
|
|
|
|
# pytorch run
|
|
ref = x_data.to(torch_to_dtype).to(torch.float32)
|
|
# compare diff
|
|
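        # Both paths perform the same conversion, so the results must match bit-exactly (atol=0).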
np.testing.assert_allclose(ref.cpu().numpy(),
|
|
outputs['output'].to(torch.float32),
|
|
atol=0)
|
|
|
|
def test_cast(self):
|
|
dtype = 'float16'
|
|
x_data = torch.randn(2, 2, 3, 6, dtype=torch.float16)
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
cast = tensorrt_llm.layers.Cast('float32')
|
|
output = cast.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data.numpy()})
|
|
|
|
# pytorch run
|
|
ref = x_data.to(torch.float32)
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.cpu().numpy(),
|
|
outputs['output'],
|
|
atol=1e-6)
|
|
|
|
def test_mish(self):
|
|
# test data
|
|
dtype = 'float32'
|
|
x_data = torch.randn(2, 2, 3, 6)
|
|
m = torch.nn.Mish()
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
with tensorrt_llm.net_guard(net):
|
|
network = tensorrt_llm.default_trtnet()
|
|
x = Tensor(name='x',
|
|
shape=x_data.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
mish = tensorrt_llm.layers.Mish()
|
|
output = mish.forward(x).trt_tensor
|
|
output.name = 'output'
|
|
network.mark_output(output)
|
|
|
|
# trt run
|
|
build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network))
|
|
with TrtRunner(build_engine) as runner:
|
|
outputs = runner.infer(feed_dict={'x': x_data.numpy()})
|
|
|
|
# pytorch run
|
|
with torch.no_grad():
|
|
ref = m(x_data)
|
|
|
|
# compare diff
|
|
np.testing.assert_allclose(ref.cpu().numpy(),
|
|
outputs['output'],
|
|
atol=1e-6)
|
|
|
|
    # The activation memory usage baseline comes from `session.engine.device_memory_size_v2` and is hardcoded here, since it should not change much across platforms as long as MHA fusion succeeds.
|
|
@parameterized.expand(
|
|
[
|
|
(
|
|
12, 512, 16, 64, 'float16', PositionEmbeddingType.alibi, False,
|
|
(402653184, 62914560)
|
|
            ),  # TRT has GPU buffer management issues with fmha + alibi, so this baseline was measured without fused MHA.
|
|
(128, 128, 12, 32, 'float16', PositionEmbeddingType.alibi, True,
|
|
(201326592, 62914560)),
|
|
(1, 200, 8, 128, 'float32', PositionEmbeddingType.alibi, False,
|
|
5017600),
|
|
(48, 30, 24, 80, 'float32', PositionEmbeddingType.alibi, True,
|
|
55296000),
|
|
(12, 512, 16, 64, 'float16', PositionEmbeddingType.learned_absolute,
|
|
False, (88113152, 402653184)),
|
|
(128, 128, 12, 32, 'float16',
|
|
PositionEmbeddingType.learned_absolute, True,
|
|
(88866816, 201326592)),
|
|
(1, 200, 8, 128, 'float32', PositionEmbeddingType.learned_absolute,
|
|
False, 5017600),
|
|
(48, 30, 24, 80, 'float32', PositionEmbeddingType.learned_absolute,
|
|
True, 55296000),
|
|
(2, 128, 4, 64, 'float16', PositionEmbeddingType.learned_absolute,
|
|
True, 35588608, True),
|
|
(2, 128, 4, 64, 'float32', PositionEmbeddingType.learned_absolute,
|
|
True, 36833280, True),
|
|
],
|
|
name_func=unittest_name_func)
|
|
def test_attention(self,
|
|
batch_size,
|
|
seq_len,
|
|
head_num,
|
|
head_size,
|
|
dtype,
|
|
pos_emb_type,
|
|
causal_mask,
|
|
act_mem_baseline: int | Tuple[int, int] | None = None,
|
|
use_plugin=False):
|
|
|
|
hidden_size = head_num * head_size
|
|
|
|
torch_dtype = str_dtype_to_torch(dtype)
|
|
mean = 0.0
|
|
std_dev = 0.02 if dtype == "float32" else 0.005
|
|
|
|
hidden_states = torch.empty(size=[batch_size, seq_len, hidden_size],
|
|
dtype=torch_dtype,
|
|
device='cuda')
|
|
hidden_states.normal_(mean, std_dev)
|
|
|
|
        # TODO: switch to random context lengths once the torch ref supports the non-padded format
|
|
context_lengths = torch.full([batch_size],
|
|
seq_len,
|
|
dtype=torch.int32,
|
|
device='cuda')
|
|
|
|
if use_plugin:
|
|
# Only generate 1 step
|
|
max_seq_len = seq_len + 1
|
|
|
|
            # Zero means a "valid" token and one means invalid; since the torch ref does not support a mask, make all tokens valid here.
|
|
host_past_key_value_lengths = torch.tensor([0] * batch_size,
|
|
dtype=torch.int32)
|
|
|
|
# the max kv cache length for each layer.
|
|
# single tensor since we only have 1 layer here.
|
|
host_max_attention_window_sizes = torch.tensor([max_seq_len],
|
|
dtype=torch.int32)
|
|
host_sink_token_length = torch.tensor([0], dtype=torch.int32)
|
|
|
|
sequence_length = torch.full([batch_size],
|
|
seq_len,
|
|
dtype=torch.int32,
|
|
device='cuda')
|
|
            # Even in the context phase, the kv cache tensors cannot be empty for the plugin;
            # otherwise, there will be a cublas execution error. The actual shape info
            # is passed to the plugin via the `sequence_length` tensor.
|
|
kv_shape = (batch_size, 2, head_num, max_seq_len, head_size)
|
|
past_key_value = torch.randn(kv_shape,
|
|
dtype=torch_dtype,
|
|
device='cuda')
|
|
cache_indirection = torch.full((
|
|
batch_size,
|
|
1,
|
|
max_seq_len,
|
|
),
|
|
0,
|
|
dtype=torch.int32,
|
|
device='cuda')
|
|
|
|
host_request_types = torch.tensor([0] * batch_size,
|
|
dtype=torch.int32,
|
|
device='cpu')
|
|
|
|
perf_knob_tensor_size = 16
|
|
host_runtime_perf_knobs_tensor = torch.tensor([-1] *
|
|
perf_knob_tensor_size,
|
|
dtype=torch.int64,
|
|
device='cpu')
|
|
host_context_progress = torch.tensor([0],
|
|
dtype=torch.int64,
|
|
device='cpu')
|
|
|
|
q_weight = torch.empty(size=[hidden_size, hidden_size],
|
|
dtype=torch_dtype)
|
|
torch.nn.init.xavier_uniform_(q_weight)
|
|
|
|
# The initialization here is chosen to minimize computation after the
|
|
# QKV BMMs in order to reduce the amount of differences from FP accumulation.
|
|
# We set K and V weights to the identity matrix so that the input is copied
|
|
# without doing any accumulation. Additionally, we set the output projection
|
|
# to the identity for the same reason.
|
|
# The main purpose of these tests is to check the QK^T BMM + Softmax + SV BMM.
|
|
eye_weight = torch.eye(hidden_size, dtype=torch_dtype)
|
|
qkv_weight = torch.cat([q_weight, eye_weight, eye_weight], dim=-1)
|
|
|
|
out_weight = eye_weight
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
net.plugin_config.to_legacy_setting()
|
|
if use_plugin:
|
|
net.plugin_config.gpt_attention_plugin = dtype
|
|
with tensorrt_llm.net_guard(net):
|
|
trt_hidden_states = Tensor(
|
|
name='hidden_states',
|
|
shape=hidden_states.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
context_lengths_tensor = Tensor(
|
|
name='context_lengths',
|
|
shape=context_lengths.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
|
|
if use_plugin:
|
|
host_request_types_tensor = Tensor(
|
|
name='host_request_types',
|
|
shape=host_request_types.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
past_key_value_tensor = Tensor(
|
|
name='past_key_value',
|
|
shape=tuple(past_key_value.shape),
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
sequence_length_tensor = Tensor(
|
|
name='sequence_length',
|
|
shape=tuple(sequence_length.shape),
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
host_past_key_value_lengths_tensor = Tensor(
|
|
name='host_past_key_value_lengths',
|
|
shape=tuple(host_past_key_value_lengths.shape),
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
host_max_attention_window_sizes_tensor = Tensor(
|
|
name='host_max_attention_window_sizes',
|
|
shape=tuple(host_max_attention_window_sizes.shape),
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
host_sink_token_length_tensor = Tensor(
|
|
name='host_sink_token_length',
|
|
shape=tuple(host_sink_token_length.shape),
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
cache_indirection_tensor = Tensor(
|
|
name='cache_indirection',
|
|
shape=tuple(cache_indirection.shape),
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
host_runtime_perf_knobs = Tensor(
|
|
name='host_runtime_perf_knobs',
|
|
shape=[16],
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int64'))
|
|
host_context_progress_tensor = Tensor(
|
|
name='host_context_progress',
|
|
shape=[1],
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int64'))
|
|
|
|
mask_type = tensorrt_llm.layers.AttentionMaskType.padding
|
|
if causal_mask:
|
|
mask_type = tensorrt_llm.layers.AttentionMaskType.causal
|
|
|
|
attn_layer = tensorrt_llm.layers.Attention(
|
|
local_layer_idx=0,
|
|
hidden_size=hidden_size,
|
|
num_attention_heads=head_num,
|
|
max_position_embeddings=seq_len,
|
|
attention_mask_type=mask_type,
|
|
position_embedding_type=pos_emb_type,
|
|
bias=False)
|
|
|
|
attn_layer.qkv.weight.value = np.ascontiguousarray(
|
|
qkv_weight.cpu().numpy().transpose([1, 0]))
|
|
attn_layer.dense.weight.value = np.ascontiguousarray(
|
|
out_weight.cpu().numpy().transpose([1, 0]))
|
|
input_tensor = trt_hidden_states
|
|
if use_plugin:
|
|
output, present_key_value = attn_layer(
|
|
input_tensor,
|
|
use_cache=True,
|
|
kv_cache_params=KeyValueCacheParams(
|
|
past_key_value=[past_key_value_tensor],
|
|
host_past_key_value_lengths=
|
|
host_past_key_value_lengths_tensor,
|
|
host_max_attention_window_sizes=
|
|
host_max_attention_window_sizes_tensor,
|
|
host_sink_token_length=host_sink_token_length_tensor,
|
|
cache_indirection=cache_indirection_tensor),
|
|
attention_params=AttentionParams(
|
|
sequence_length=sequence_length_tensor,
|
|
context_lengths=context_lengths_tensor,
|
|
host_request_types=host_request_types_tensor,
|
|
max_context_length=seq_len,
|
|
host_runtime_perf_knobs=host_runtime_perf_knobs,
|
|
host_context_progress=host_context_progress_tensor))
|
|
assert isinstance(output, Tensor)
|
|
output = output
|
|
present_key_value.mark_output(
|
|
'present_key_value', tensorrt_llm.str_dtype_to_trt(dtype))
|
|
else:
|
|
output = attn_layer(input_tensor)
|
|
output.mark_output('output', tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
builder_config = builder.create_builder_config(name='attention',
|
|
precision=dtype)
|
|
# Build engine
|
|
engine_buffer = builder.build_engine(net, builder_config)
|
|
session = tensorrt_llm.runtime.Session.from_serialized_engine(
|
|
engine_buffer)
|
|
act_mem = session.engine.device_memory_size_v2
|
|
|
|
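        # Tuple baselines store (pre-TRT-10, TRT-10) values; pick the entry matching the installed TensorRT major version.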
        if act_mem_baseline is not None:
|
|
if isinstance(act_mem_baseline, tuple):
|
|
act_mem_baseline = act_mem_baseline[
|
|
1] if trt.__version__.startswith(
|
|
"10") else act_mem_baseline[0]
|
|
if not pos_emb_type.is_alibi():
|
|
# TRT has gpu buffer management issues with fmha + alibi.
|
|
assert act_mem < act_mem_baseline * (1 + 0.1)
|
|
assert act_mem > act_mem_baseline * (
|
|
1 - 0.1
|
|
            ), f"The MR's activation memory usage is better than the baseline; please update test_attention in test_layer.py. The outdated baseline is {act_mem_baseline}, and the new baseline is {act_mem}."
|
|
|
|
stream = torch.cuda.current_stream().cuda_stream
|
|
|
|
if use_plugin:
|
|
inputs = {
|
|
'hidden_states': hidden_states,
|
|
'past_key_value': past_key_value,
|
|
'sequence_length': sequence_length,
|
|
'host_past_key_value_lengths': host_past_key_value_lengths,
|
|
'host_max_attention_window_sizes':
|
|
host_max_attention_window_sizes,
|
|
'host_sink_token_length': host_sink_token_length,
|
|
'context_lengths': context_lengths,
|
|
'host_request_types': host_request_types,
|
|
'cache_indirection': cache_indirection,
|
|
'host_runtime_perf_knobs': host_runtime_perf_knobs_tensor,
|
|
'host_context_progress': host_context_progress
|
|
}
|
|
outputs = {
|
|
'output':
|
|
torch.empty(hidden_states.shape,
|
|
dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
|
|
device='cuda'),
|
|
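                # The plugin updates the KV cache in place, so present_key_value reuses the past_key_value buffer.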
'present_key_value':
|
|
past_key_value,
|
|
}
|
|
else:
|
|
inputs = {
|
|
'hidden_states': hidden_states,
|
|
'context_lengths': context_lengths,
|
|
}
|
|
outputs = {
|
|
'output':
|
|
torch.empty(hidden_states.shape,
|
|
dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
|
|
device='cuda'),
|
|
}
|
|
|
|
session.run(inputs=inputs, outputs=outputs, stream=stream)
|
|
torch.cuda.synchronize()
|
|
|
|
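        # Torch reference: recompute the packed QKV projection and run it through the packed-attention reference.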
packed_torch_qkv = hidden_states.to("cuda") @ qkv_weight.to("cuda")
|
|
packed_torch_qkv = packed_torch_qkv.reshape(
|
|
[batch_size, seq_len, 3, head_num, head_size])
|
|
|
|
alibi_bias = None
|
|
if pos_emb_type == PositionEmbeddingType.alibi:
|
|
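            # Build the ALiBi bias with the HF Bloom helper and broadcast it over the query dimension.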
mask = torch.ones(size=[batch_size, seq_len], device="cuda")
|
|
alibi_bias = build_alibi_tensor(mask, head_num, torch.float32)
|
|
alibi_bias = alibi_bias.reshape([batch_size, head_num, 1, seq_len])
|
|
|
|
mha_out, _ = attention_qkvpacked_ref(packed_torch_qkv,
|
|
causal=causal_mask,
|
|
upcast=False,
|
|
bias=alibi_bias)
|
|
torch_out = mha_out.reshape([batch_size, seq_len, hidden_size])
|
|
|
|
trt_output = outputs['output']
|
|
|
|
a_tol = 5e-5 if (dtype == "float32" and not use_plugin) else 2e-3
|
|
np.testing.assert_allclose(torch_out.cpu().numpy(),
|
|
trt_output.cpu().numpy(),
|
|
atol=a_tol,
|
|
verbose=True)
|
|
|
|
@parameterized.expand(list(
|
|
product([3], [16], [1], [1024], [16], ['context', 'generation'],
|
|
["float32", "float16", "bfloat16"], [True, False],
|
|
[True, False])),
|
|
name_func=unittest_name_func)
|
|
def test_mamba(self, batch_size, in_seq_len, out_seq_len, d_model, d_state,
|
|
req_type, dtype, remove_padding, use_plugin):
|
|
|
|
if not use_plugin and remove_padding:
|
|
pytest.skip(
|
|
"Skipping remove input padding without mamba conv1d plugin")
|
|
|
|
# configs
|
|
device = "cuda"
|
|
d_conv = 4
|
|
expand = 2
|
|
dt_rank = "auto"
|
|
bias = False
|
|
d_inner = int(expand * d_model)
|
|
seqlen_offset = 0 if req_type == 'context' else in_seq_len
|
|
seq_len = in_seq_len if req_type == 'context' else out_seq_len
|
|
|
|
# test data
|
|
torch_dtype = str_dtype_to_torch(dtype)
|
|
mean = 0.0
|
|
std_dev = 0.1 if dtype == "float32" else 0.05
|
|
|
|
if req_type == 'context':
|
|
last_token_ids = torch.randint(1,
|
|
in_seq_len + 1,
|
|
size=(batch_size, ),
|
|
dtype=torch.int32,
|
|
device=device)
|
|
last_token_ids[0] = in_seq_len
|
|
host_context_lengths = last_token_ids.detach().clone().cpu()
|
|
else:
|
|
last_token_ids = torch.ones(size=[batch_size],
|
|
dtype=torch.int32,
|
|
device=device)
|
|
host_context_lengths = last_token_ids.detach().clone().cpu()
|
|
|
|
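        # The mamba_conv1d plugin expects conv_state as [batch, d_conv - 1, d_inner], while the plain TRT path uses [batch, d_inner, d_conv - 1]; conv_indices (offset by the sequence lengths) selects the entries used to form the returned conv state in the matching layout.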
if use_plugin:
|
|
trt_conv_state_shape = [batch_size, d_conv - 1, d_inner]
|
|
conv_indices = torch.arange(0,
|
|
d_conv - 1,
|
|
dtype=torch.int32,
|
|
device=device).view([1, d_conv - 1, 1])
|
|
else:
|
|
trt_conv_state_shape = [batch_size, d_inner, d_conv - 1]
|
|
conv_indices = torch.arange(0,
|
|
d_conv - 1,
|
|
dtype=torch.int32,
|
|
device=device).view([1, 1, d_conv - 1])
|
|
offsets = last_token_ids.view([batch_size, 1, 1])
|
|
conv_indices = conv_indices.expand(trt_conv_state_shape) + offsets
|
|
|
|
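        # With packed inputs (remove_padding), last_token_ids becomes an inclusive prefix sum of the sequence lengths.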
if remove_padding:
|
|
last_token_ids = torch.cumsum(last_token_ids,
|
|
dim=0,
|
|
dtype=torch.int32).to(device)
|
|
total_num_tokens = last_token_ids[batch_size - 1]
|
|
else:
|
|
total_num_tokens = batch_size * seq_len
|
|
|
|
if remove_padding:
|
|
hidden_states = torch.empty(size=[total_num_tokens, d_model],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
output = torch.zeros(size=[total_num_tokens, d_model],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
else:
|
|
hidden_states = torch.empty(size=[batch_size, seq_len, d_model],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
output = torch.zeros(size=[batch_size, seq_len, d_model],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
|
|
hidden_states.normal_(mean, std_dev)
|
|
if req_type == 'context':
|
|
conv_state = torch.zeros(size=[batch_size, d_inner, d_conv - 1],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
else:
|
|
conv_state = torch.randn(size=[batch_size, d_inner, d_conv - 1],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
if req_type == 'context':
|
|
ssm_state = torch.empty(size=[batch_size, d_state, d_inner],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
else:
|
|
ssm_state = torch.randn(size=[batch_size, d_state, d_inner],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
|
|
host_request_types = torch.tensor([0 if req_type == 'context' else 1] *
|
|
batch_size,
|
|
dtype=torch.int32)
|
|
|
|
present_conv_state = torch.zeros(size=trt_conv_state_shape,
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
|
|
hidden_states_ref = hidden_states.detach().clone()
|
|
out_ref = output.detach().clone()
|
|
if req_type == 'context':
|
|
conv_state_ref = torch.zeros(size=[batch_size, d_inner, d_conv],
|
|
dtype=torch_dtype,
|
|
device=device).detach()
|
|
else:
|
|
conv_state_ref = torch.concat(
|
|
(torch.zeros(size=[batch_size, d_inner, 1],
|
|
dtype=torch_dtype,
|
|
device=device), conv_state),
|
|
dim=2).detach().clone()
|
|
ssm_state_ref = ssm_state.detach().clone()
|
|
|
|
# get torch layer
|
|
mamba_torch = mamba_ref(d_model,
|
|
d_state,
|
|
d_conv,
|
|
expand,
|
|
dt_rank,
|
|
True,
|
|
bias,
|
|
device=device,
|
|
dtype=torch_dtype)
|
|
|
|
# init weights
|
|
for module in mamba_torch.modules():
|
|
if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d)):
|
|
if module.bias is not None:
|
|
torch.nn.init.normal_(module.bias, std=std_dev)
|
|
torch.nn.init.normal_(module.weight, std=std_dev)
|
|
|
|
A = -torch.rand(d_state, d_inner, device=device) - 1.0
|
|
D = torch.randn(d_inner, device=device)
|
|
dt_bias = torch.rand(d_inner, device=device) - 4.0
|
|
|
|
mamba_torch.A.data = A.detach().clone()
|
|
mamba_torch.D.data = D.detach().clone()
|
|
mamba_torch.dt_proj.bias.data = dt_bias.detach().clone()
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
net = builder.create_network()
|
|
if use_plugin:
|
|
net.plugin_config.mamba_conv1d_plugin = dtype
|
|
else:
|
|
net.plugin_config.mamba_conv1d_plugin = None
|
|
if remove_padding:
|
|
net.plugin_config.remove_input_padding = True
|
|
else:
|
|
net.plugin_config.remove_input_padding = False
|
|
net.plugin_config.paged_state = False
|
|
|
|
with tensorrt_llm.net_guard(net):
|
|
hidden_states_tensor = Tensor(
|
|
name='hidden_states',
|
|
shape=hidden_states.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
conv_state_tensor = Tensor(
|
|
name='conv_state',
|
|
shape=trt_conv_state_shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
ssm_state_tensor = Tensor(
|
|
name='ssm_state',
|
|
shape=ssm_state.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
host_request_types_tensor = Tensor(
|
|
name='host_request_types',
|
|
shape=host_request_types.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
last_token_ids_tensor = Tensor(
|
|
name='last_token_ids',
|
|
shape=last_token_ids.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
host_context_lengths_tensor = Tensor(
|
|
name='host_context_lengths',
|
|
shape=host_context_lengths.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
conv_indices_tensor = Tensor(
|
|
name='conv_indices',
|
|
shape=trt_conv_state_shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
mamba_layer = tensorrt_llm.layers.Mamba(d_model=d_model,
|
|
d_inner=d_inner,
|
|
d_state=d_state,
|
|
d_conv=d_conv,
|
|
dt_rank=dt_rank,
|
|
bias=bias,
|
|
dtype=dtype)
|
|
mamba_layer.A.value = torch_to_numpy(A.detach().cpu())
|
|
mamba_layer.D.value = torch_to_numpy(D.detach().cpu())
|
|
mamba_layer.dt_bias.value = torch_to_numpy(dt_bias.detach().cpu())
|
|
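            # The torch reference keeps a single in_proj; split its weight into the x (first d_inner rows) and z (remaining rows) projections expected by the TRT layer.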
mamba_layer.in_proj_x.weight.value = torch_to_numpy(
|
|
mamba_torch.in_proj.weight[
|
|
0:d_inner,
|
|
].detach().cpu())
|
|
mamba_layer.in_proj_z.weight.value = torch_to_numpy(
|
|
mamba_torch.in_proj.weight[
|
|
d_inner:,
|
|
].detach().cpu())
|
|
mamba_layer.out_proj.weight.value = torch_to_numpy(
|
|
mamba_torch.out_proj.weight.detach().cpu())
|
|
if bias:
|
|
mamba_layer.in_proj_x.bias.value = torch_to_numpy(
|
|
mamba_torch.in_proj.bias[
|
|
0:d_inner,
|
|
].detach().cpu())
|
|
mamba_layer.in_proj_z.bias.value = torch_to_numpy(
|
|
mamba_torch.in_proj.bias[
|
|
d_inner:,
|
|
].detach().cpu())
|
|
mamba_layer.out_proj.bias.value = torch_to_numpy(
|
|
mamba_torch.out_proj.bias.detach().cpu())
|
|
mamba_layer.conv1d.weight.value = torch_to_numpy(
|
|
mamba_torch.conv1d.weight.detach().unsqueeze(3).cpu())
|
|
mamba_layer.conv1d.bias.value = torch_to_numpy(
|
|
mamba_torch.conv1d.bias.detach().cpu())
|
|
mamba_layer.x_proj.weight.value = torch_to_numpy(
|
|
mamba_torch.x_proj.weight.detach().cpu())
|
|
mamba_layer.dt_proj.weight.value = torch_to_numpy(
|
|
mamba_torch.dt_proj.weight.detach().cpu())
|
|
|
|
outputs = mamba_layer(
|
|
hidden_states_tensor,
|
|
conv_state_tensor,
|
|
ssm_state_tensor,
|
|
host_request_types_tensor,
|
|
last_token_ids_tensor,
|
|
host_context_lengths=host_context_lengths_tensor,
|
|
conv_indices=conv_indices_tensor)
|
|
net._mark_output(outputs[0],
|
|
'output',
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
net._mark_output(outputs[1],
|
|
'present_conv_state',
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
net._mark_output(outputs[2],
|
|
'present_ssm_state',
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
if use_plugin:
|
|
trt_conv_state = conv_state.permute(0, 2, 1).contiguous()
|
|
else:
|
|
trt_conv_state = conv_state.clone().detach()
|
|
trt_conv_indices = conv_indices.clone().detach()
|
|
# trt run
|
|
inputs = {
|
|
'hidden_states': hidden_states,
|
|
'conv_state': trt_conv_state,
|
|
'ssm_state': ssm_state,
|
|
'host_request_types': host_request_types,
|
|
'last_token_ids': last_token_ids,
|
|
'host_context_lengths': host_context_lengths,
|
|
'conv_indices': trt_conv_indices,
|
|
}
|
|
outputs = {
|
|
'output': output,
|
|
'present_conv_state': present_conv_state,
|
|
'present_ssm_state': ssm_state,
|
|
}
|
|
|
|
stream = torch.cuda.current_stream()
|
|
builder_config = builder.create_builder_config(name='mamba',
|
|
precision=dtype)
|
|
engine = builder.build_engine(net, builder_config)
|
|
session = tensorrt_llm.runtime.Session.from_serialized_engine(engine)
|
|
session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)
|
|
|
|
# pytorch run
|
|
out_ref, conv_state_ref, ssm_state_ref = mamba_torch(
|
|
hidden_states_ref, last_token_ids, conv_state_ref, ssm_state_ref,
|
|
remove_padding, batch_size, seqlen_offset)
|
|
|
|
dtype_atol = {"float16": 1e-2, "float32": 5e-3, "bfloat16": 5e-2}
|
|
|
|
if not remove_padding:
|
|
# get out_mask
|
|
if req_type == 'context':
|
|
out_mask = torch.zeros(batch_size, seq_len, device=device)
|
|
for i in range(batch_size):
|
|
for j in range(last_token_ids[i]):
|
|
out_mask[i, j] = 1
|
|
out_mask = out_mask.unsqueeze(2).expand(
|
|
[batch_size, seq_len, d_model])
|
|
else:
|
|
out_mask = torch.ones(batch_size,
|
|
seq_len,
|
|
d_model,
|
|
device=device)
|
|
|
|
# compare out diff
|
|
out_ref = (out_ref * out_mask).detach().to(
|
|
torch.float32).cpu().numpy()
|
|
outputs['output'][out_mask == 0] = 0
|
|
else:
|
|
out_ref = out_ref.detach().to(torch.float32).cpu().numpy()
|
|
|
|
out_trt_llm = outputs['output'].to(torch.float32).cpu().numpy()
|
|
np.testing.assert_allclose(out_ref, out_trt_llm, atol=dtype_atol[dtype])
|
|
|
|
# compare conv state diff
|
|
conv_state_ref = conv_state_ref[:, :, 1:].detach().to(
|
|
torch.float32).cpu().numpy()
|
|
conv_state_trt_llm = outputs['present_conv_state']
|
|
if use_plugin:
|
|
conv_state_trt_llm = conv_state_trt_llm.permute(0, 2,
|
|
1).contiguous()
|
|
conv_state_trt_llm = conv_state_trt_llm.to(torch.float32).cpu().numpy()
|
|
np.testing.assert_allclose(conv_state_ref,
|
|
conv_state_trt_llm,
|
|
atol=dtype_atol[dtype])
|
|
|
|
# compare ssm state diff
|
|
ssm_state_ref = ssm_state_ref.detach().to(torch.float32).cpu().numpy()
|
|
ssm_state_trt_llm = outputs['present_ssm_state']
|
|
ssm_state_trt_llm = ssm_state_trt_llm.to(torch.float32).cpu().numpy()
|
|
np.testing.assert_allclose(ssm_state_ref,
|
|
ssm_state_trt_llm,
|
|
atol=dtype_atol[dtype])
|
|
|
|
@parameterized.expand(
|
|
# simple tests
|
|
list(
|
|
product([3], [16], [1], [1024], [128], [64], [256], [1, 4],
|
|
['context', 'generation'],
|
|
["float32", "float16", "bfloat16"], [True, False],
|
|
[True, False])) +
|
|
# P=8x and H=2x
|
|
list(
|
|
product([2, 4, 8, 16], [16], [1], [160, 320, 640], [128], [80],
|
|
[256], [1], ['context', 'generation'], ["float16"], [True],
|
|
[True])),
|
|
name_func=unittest_name_func)
|
|
def test_mamba2(self, batch_size, in_seq_len, out_seq_len, d_model, d_state,
|
|
headdim, chunk_size, ngroups, req_type, dtype,
|
|
remove_padding, use_plugin):
|
|
|
|
if not use_plugin and remove_padding:
|
|
pytest.skip(
|
|
"Skipping remove input padding without mamba conv1d plugin")
|
|
if dtype == 'float32' and req_type == 'context':
|
|
pytest.skip(
|
|
"Mamba2 layer only support float16 and bfloat16 in context phase"
|
|
)
|
|
|
|
# configs
|
|
device = "cuda"
|
|
d_conv = 4
|
|
expand = 2
|
|
bias = False
|
|
rmsnorm = True
|
|
d_inner = int(expand * d_model)
|
|
nheads = d_inner // headdim
|
|
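        # Mamba2 runs the causal conv over x, B and C together, so the conv channel count is d_inner + 2 * ngroups * d_state.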
conv_dim = d_inner + 2 * ngroups * d_state
|
|
seqlen_offset = 0 if req_type == 'context' else in_seq_len
|
|
seq_len = in_seq_len if req_type == 'context' else out_seq_len
|
|
|
|
# test data
|
|
torch_dtype = str_dtype_to_torch(dtype)
|
|
mean = 0.0
|
|
std_dev = 0.05 if dtype == "float32" else 0.02
|
|
torch.random.manual_seed(0)
|
|
|
|
if req_type == 'context':
|
|
last_token_ids = torch.randint(1,
|
|
in_seq_len + 1,
|
|
size=(batch_size, ),
|
|
dtype=torch.int32,
|
|
device=device)
|
|
last_token_ids[0] = in_seq_len
|
|
host_context_lengths = last_token_ids.detach().clone().cpu()
|
|
else:
|
|
last_token_ids = torch.ones(size=[batch_size],
|
|
dtype=torch.int32,
|
|
device=device)
|
|
host_context_lengths = last_token_ids.detach().clone().cpu()
|
|
|
|
if use_plugin:
|
|
trt_conv_state_shape = [batch_size, d_conv - 1, conv_dim]
|
|
conv_indices = torch.arange(0,
|
|
d_conv - 1,
|
|
dtype=torch.int32,
|
|
device=device).view([1, d_conv - 1, 1])
|
|
else:
|
|
trt_conv_state_shape = [batch_size, conv_dim, d_conv - 1]
|
|
conv_indices = torch.arange(0,
|
|
d_conv - 1,
|
|
dtype=torch.int32,
|
|
device=device).view([1, 1, d_conv - 1])
|
|
offsets = last_token_ids.view([batch_size, 1, 1])
|
|
conv_indices = conv_indices.expand(trt_conv_state_shape) + offsets
|
|
|
|
if remove_padding:
|
|
last_token_ids = torch.cumsum(last_token_ids,
|
|
dim=0,
|
|
dtype=torch.int32).to(device)
|
|
total_num_tokens = last_token_ids[batch_size - 1]
|
|
else:
|
|
total_num_tokens = batch_size * seq_len
|
|
|
|
if remove_padding:
|
|
hidden_states = torch.empty(size=[total_num_tokens, d_model],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
output = torch.zeros(size=[total_num_tokens, d_model],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
else:
|
|
hidden_states = torch.empty(size=[batch_size, seq_len, d_model],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
output = torch.zeros(size=[batch_size, seq_len, d_model],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
hidden_states.normal_(mean, std_dev)
|
|
|
|
if req_type == 'context':
|
|
conv_state = torch.zeros(size=[batch_size, conv_dim, d_conv - 1],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
else:
|
|
conv_state = torch.randn(size=[batch_size, conv_dim, d_conv - 1],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
if req_type == 'context':
|
|
ssm_state = torch.empty(size=[batch_size, nheads, d_state, headdim],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
else:
|
|
ssm_state = torch.randn(size=[batch_size, nheads, d_state, headdim],
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
|
|
host_request_types = torch.tensor([0 if req_type == 'context' else 1] *
|
|
batch_size,
|
|
dtype=torch.int32)
|
|
|
|
present_conv_state = torch.zeros(size=trt_conv_state_shape,
|
|
dtype=torch_dtype,
|
|
device=device)
|
|
|
|
hidden_states_ref = hidden_states.detach().clone()
|
|
out_ref = output.detach().clone()
|
|
if req_type == 'context':
|
|
conv_state_ref = torch.zeros(size=[batch_size, conv_dim, d_conv],
|
|
dtype=torch_dtype,
|
|
device=device).detach()
|
|
else:
|
|
conv_state_ref = torch.concat(
|
|
(torch.zeros(size=[batch_size, conv_dim, 1],
|
|
dtype=torch_dtype,
|
|
device=device), conv_state),
|
|
dim=2).detach().clone()
|
|
ssm_state_ref = ssm_state.detach().clone()
|
|
|
|
# get torch layer
|
|
mamba2_torch = mamba2_ref(d_model,
|
|
d_state,
|
|
d_conv,
|
|
expand,
|
|
headdim,
|
|
ngroups,
|
|
chunk_size,
|
|
True,
|
|
bias,
|
|
rmsnorm=rmsnorm,
|
|
device=device,
|
|
dtype=torch_dtype)
|
|
|
|
# init weights
|
|
for module in mamba2_torch.modules():
|
|
if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d)):
|
|
if module.bias is not None:
|
|
torch.nn.init.normal_(module.bias, std=std_dev)
|
|
torch.nn.init.normal_(module.weight, std=std_dev)
|
|
|
|
A = -torch.rand(nheads, device=device) - 1.0
|
|
D = torch.randn(nheads, device=device)
|
|
dt_bias = torch.rand(nheads, device=device) - 4.0
|
|
norm_weight = torch.randn(d_inner, device=device)
|
|
|
|
mamba2_torch.A.data = A.detach().clone()
|
|
mamba2_torch.D.data = D.detach().clone()
|
|
mamba2_torch.dt_bias.data = dt_bias.detach().clone()
|
|
if rmsnorm:
|
|
mamba2_torch.norm_weight.data = norm_weight.detach().clone()
|
|
|
|
# construct trt network
|
|
builder = tensorrt_llm.Builder()
|
|
        builder.strongly_typed = False  # Test needs to run in weakly typed mode
|
|
net = builder.create_network()
|
|
if use_plugin:
|
|
net.plugin_config.mamba_conv1d_plugin = dtype
|
|
net.plugin_config.gemm_plugin = dtype
|
|
else:
|
|
net.plugin_config.mamba_conv1d_plugin = None
|
|
net.plugin_config.gemm_plugin = None
|
|
if remove_padding:
|
|
net.plugin_config.remove_input_padding = True
|
|
else:
|
|
net.plugin_config.remove_input_padding = False
|
|
net.plugin_config.paged_state = False
|
|
|
|
with tensorrt_llm.net_guard(net):
|
|
hidden_states_tensor = Tensor(
|
|
name='hidden_states',
|
|
shape=hidden_states.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
conv_state_tensor = Tensor(
|
|
name='conv_state',
|
|
shape=trt_conv_state_shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
ssm_state_tensor = Tensor(
|
|
name='ssm_state',
|
|
shape=ssm_state.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
host_request_types_tensor = Tensor(
|
|
name='host_request_types',
|
|
shape=host_request_types.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
last_token_ids_tensor = Tensor(
|
|
name='last_token_ids',
|
|
shape=last_token_ids.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
host_context_lengths_tensor = Tensor(
|
|
name='host_context_lengths',
|
|
shape=host_context_lengths.shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
conv_indices_tensor = Tensor(
|
|
name='conv_indices',
|
|
shape=trt_conv_state_shape,
|
|
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
|
|
mamba2_layer = tensorrt_llm.layers.Mamba2(d_model=d_model,
|
|
d_inner=d_inner,
|
|
d_state=d_state,
|
|
d_conv=d_conv,
|
|
headdim=headdim,
|
|
ngroups=ngroups,
|
|
chunk_size=chunk_size,
|
|
bias=bias,
|
|
rmsnorm=rmsnorm,
|
|
dtype=dtype)
|
|
mamba2_layer.A.value = torch_to_numpy(A.detach().cpu())
|
|
mamba2_layer.D.value = torch_to_numpy(D.detach().cpu())
|
|
mamba2_layer.dt_bias.value = torch_to_numpy(dt_bias.detach().cpu())
|
|
mamba2_layer.norm.weight.value = torch_to_numpy(
|
|
norm_weight.detach().cpu())
|
|
mamba2_layer.in_proj.weight.value = torch_to_numpy(
|
|
mamba2_torch.in_proj.weight.detach().cpu())
|
|
mamba2_layer.out_proj.weight.value = torch_to_numpy(
|
|
mamba2_torch.out_proj.weight.detach().cpu())
|
|
if bias:
|
|
mamba2_layer.in_proj.bias.value = torch_to_numpy(
|
|
mamba2_torch.in_proj.bias.detach().cpu())
|
|
mamba2_layer.out_proj.bias.value = torch_to_numpy(
|
|
mamba2_torch.out_proj.bias.detach().cpu())
|
|
mamba2_layer.conv1d.weight.value = torch_to_numpy(
|
|
mamba2_torch.conv1d.weight.detach().unsqueeze(3).cpu())
|
|
mamba2_layer.conv1d.bias.value = torch_to_numpy(
|
|
mamba2_torch.conv1d.bias.detach().cpu())
|
|
|
|
outputs = mamba2_layer(
|
|
hidden_states_tensor,
|
|
conv_state_tensor,
|
|
ssm_state_tensor,
|
|
host_request_types_tensor,
|
|
last_token_ids_tensor,
|
|
host_context_lengths=host_context_lengths_tensor,
|
|
conv_indices=conv_indices_tensor)
|
|
net._mark_output(outputs[0],
|
|
'output',
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
net._mark_output(outputs[1],
|
|
'present_conv_state',
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
net._mark_output(outputs[2],
|
|
'present_ssm_state',
|
|
dtype=tensorrt_llm.str_dtype_to_trt(dtype))
|
|
|
|
        if use_plugin:
            trt_conv_state = conv_state.permute(0, 2, 1).contiguous()
        else:
            trt_conv_state = conv_state.clone().detach()
        trt_conv_indices = conv_indices.clone().detach()
        # trt run
        inputs = {
            'hidden_states': hidden_states,
            'conv_state': trt_conv_state,
            'ssm_state': ssm_state,
            'host_request_types': host_request_types,
            'last_token_ids': last_token_ids,
            'host_context_lengths': host_context_lengths,
            'conv_indices': trt_conv_indices,
        }
        outputs = {
            'output': output,
            'present_conv_state': present_conv_state,
            'present_ssm_state': ssm_state,
        }

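        # Build the engine and run it through the TRT-LLM session; results are
        # written into the preallocated buffers above. Note that the ssm_state
        # tensor doubles as the present_ssm_state output buffer, so the SSM
        # state is effectively updated in place.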
        stream = torch.cuda.current_stream()
        builder_config = builder.create_builder_config(name='mamba2',
                                                       precision=dtype)
        engine = builder.build_engine(net, builder_config)
        session = tensorrt_llm.runtime.Session.from_serialized_engine(engine)
        session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)

        # pytorch run
        out_ref, conv_state_ref, ssm_state_ref = mamba2_torch(
            hidden_states_ref, last_token_ids, conv_state_ref, ssm_state_ref,
            remove_padding, batch_size, seqlen_offset)

        dtype_atol = {"float16": 1e-2, "float32": 5e-3, "bfloat16": 5e-2}

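        # In the padded layout, positions beyond each sequence's true length
        # are undefined, so both the reference output and the TRT output are
        # masked to zero there before being compared.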
        if not remove_padding:
            # get out_mask
            if req_type == 'context':
                out_mask = torch.zeros(batch_size, seq_len, device=device)
                for i in range(batch_size):
                    for j in range(last_token_ids[i]):
                        out_mask[i, j] = 1
                out_mask = out_mask.unsqueeze(2).expand(
                    [batch_size, seq_len, d_model])
            else:
                out_mask = torch.ones(batch_size,
                                      seq_len,
                                      d_model,
                                      device=device)

            # compare out diff
            out_ref = (out_ref * out_mask).detach().to(
                torch.float32).cpu().numpy()
            outputs['output'][out_mask == 0] = 0
        else:
            out_ref = out_ref.detach().to(torch.float32).cpu().numpy()

        out_trt_llm = outputs['output'].to(torch.float32).cpu().numpy()
        np.testing.assert_allclose(out_ref, out_trt_llm, atol=dtype_atol[dtype])

        # compare conv state diff
        conv_state_ref = conv_state_ref[:, :, 1:].detach().to(
            torch.float32).cpu().numpy()
        conv_state_trt_llm = outputs['present_conv_state']
        if use_plugin:
            conv_state_trt_llm = conv_state_trt_llm.permute(0, 2,
                                                            1).contiguous()
        conv_state_trt_llm = conv_state_trt_llm.to(torch.float32).cpu().numpy()
        np.testing.assert_allclose(conv_state_ref,
                                   conv_state_trt_llm,
                                   atol=dtype_atol[dtype])

        # compare ssm state diff
        ssm_state_ref = ssm_state_ref.detach().to(torch.float32).cpu().numpy()
        ssm_state_trt_llm = outputs['present_ssm_state']
        ssm_state_trt_llm = ssm_state_trt_llm.to(torch.float32).cpu().numpy()
        np.testing.assert_allclose(ssm_state_ref,
                                   ssm_state_trt_llm,
                                   atol=dtype_atol[dtype])

    @parameterized.expand(list(
        product([3], [16], [1], [1280], [1280], [10], ['context', 'generation'],
                ["float32", "float16", "bfloat16"], [True, False],
                [True, False], [True, False])),
                          name_func=unittest_name_func)
    def test_recurrent(self, batch_size, in_seq_len, out_seq_len, width,
                       lru_width, num_heads, req_type, dtype, remove_padding,
                       use_plugin, use_fused_rg_lru):

        if not use_plugin and remove_padding:
            pytest.skip(
                "Skipping remove input padding without mamba conv1d plugin")

        # configs
        device = "cuda"
        d_conv = 4
        seq_len = in_seq_len if req_type == 'context' else out_seq_len
        torch.random.manual_seed(0)

        # test data
        torch_dtype = str_dtype_to_torch(dtype)
        mean = 0.0
        std_dev = 0.1 if dtype == "float32" else 0.02

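        # In the context phase, last_token_ids holds a random prompt length per
        # sequence (the first one pinned to the full in_seq_len); in the
        # generation phase every sequence contributes a single token and
        # segment_pos carries its absolute position instead.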
        if req_type == 'context':
            last_token_ids = torch.randint(1,
                                           in_seq_len + 1,
                                           size=(batch_size, ),
                                           dtype=torch.int32,
                                           device=device)
            last_token_ids[0] = in_seq_len
            host_context_lengths = last_token_ids.detach().clone().cpu()
            segment_pos = torch.arange(0,
                                       in_seq_len,
                                       dtype=torch.int32,
                                       device=device).unsqueeze(0).repeat(
                                           batch_size, 1)
            segment_pos = segment_pos * (
                segment_pos < last_token_ids.unsqueeze(1)).type(torch.int32)
        else:
            last_token_ids = torch.ones(size=[batch_size],
                                        dtype=torch.int32,
                                        device=device)
            host_context_lengths = last_token_ids.detach().clone().cpu()
            segment_pos = torch.randint(1,
                                        in_seq_len + 1,
                                        size=(batch_size, ),
                                        dtype=torch.int32,
                                        device=device)
            segment_pos[0] = in_seq_len
            segment_pos = segment_pos.unsqueeze(1)

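        # Gather indices for the convolution-state window: the TRT path uses a
        # (d_conv - 1)-wide window while the reference keeps one extra column
        # (d_conv wide), and both index sets are shifted by each sequence's
        # length via the offsets computed below.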
        conv_indices = torch.arange(0,
                                    d_conv - 1,
                                    dtype=torch.int32,
                                    device=device)
        conv_indices_ref = torch.arange(-1,
                                        d_conv - 1,
                                        dtype=torch.int32,
                                        device=device).view([1, 1, d_conv])
        if use_plugin:
            trt_conv_state_shape = [batch_size, d_conv - 1, lru_width]
            conv_indices = conv_indices.view([1, d_conv - 1, 1])
        else:
            trt_conv_state_shape = [batch_size, lru_width, d_conv - 1]
            conv_indices = conv_indices.view([1, 1, d_conv - 1])
        offsets = last_token_ids.view([batch_size, 1, 1])
        conv_indices = conv_indices.expand(trt_conv_state_shape) + offsets
        conv_indices_ref = conv_indices_ref.expand(
            [batch_size, lru_width, d_conv]) + offsets

        if remove_padding:
            last_token_ids = torch.cumsum(last_token_ids,
                                          dim=0,
                                          dtype=torch.int32).to(device)
            total_num_tokens = last_token_ids[batch_size - 1]
            hidden_states_shape = [total_num_tokens, width]
        else:
            total_num_tokens = batch_size * seq_len
            hidden_states_shape = [batch_size, seq_len, width]

        hidden_states = torch.empty(size=hidden_states_shape,
                                    dtype=torch_dtype,
                                    device=device)
        hidden_states.normal_(mean, std_dev)
        output = torch.zeros(size=hidden_states_shape,
                             dtype=torch_dtype,
                             device=device)

        if req_type == 'context':
            conv_state = torch.zeros(size=[batch_size, lru_width, d_conv - 1],
                                     dtype=torch_dtype,
                                     device=device)
        else:
            conv_state = torch.randn(size=[batch_size, lru_width, d_conv - 1],
                                     dtype=torch_dtype,
                                     device=device)
        if req_type == 'context':
            lru_state = torch.empty(size=[batch_size, lru_width],
                                    dtype=torch.float32,
                                    device=device)
        else:
            lru_state = torch.randn(size=[batch_size, lru_width],
                                    dtype=torch.float32,
                                    device=device)

        host_request_types = torch.tensor([0 if req_type == 'context' else 1] *
                                          batch_size,
                                          dtype=torch.int32)

        present_conv_state = torch.zeros(size=trt_conv_state_shape,
                                         dtype=torch_dtype,
                                         device=device)

        hidden_states_ref = hidden_states.detach().clone()
        out_ref = output.detach().clone()
        if req_type == 'context':
            conv_state_ref = None
        else:
            conv_state_ref = torch.concat(
                (torch.zeros(size=[batch_size, lru_width, 1],
                             dtype=torch_dtype,
                             device=device), conv_state),
                dim=2).detach().clone()
        lru_state_ref = lru_state.detach().clone()

        # get torch layer
        recurrent_torch = recurrent_ref(width,
                                        lru_width,
                                        num_heads,
                                        d_conv,
                                        device=device,
                                        dtype=torch_dtype)

        # init weights
        for module in recurrent_torch.modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv1d)):
                if module.bias is not None:
                    torch.nn.init.normal_(module.bias, std=std_dev)
                torch.nn.init.normal_(module.weight, std=std_dev)

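        # Note on the initialization below: it draws a recurrence radius r in
        # [min_rad, max_rad] (sampled as r**2) and stores it as
        # log(1/r - 1) == -logit(r), which appears to be the parameterization
        # the RG-LRU reference expects for its recurrence coefficient.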
        # init recurrent_param
        min_rad, max_rad, eps = 0.9, 0.999, 1e-8
        recurrent_param = torch.randn(lru_width,
                                      device=device,
                                      dtype=torch_dtype)
        recurrent_param.uniform_(min_rad**2 + eps, max_rad**2 + eps)
        recurrent_param.log_().mul_(0.5)
        recurrent_param.neg_().exp_().sub_(1.0).log_()
        recurrent_torch.recurrent_param.data = recurrent_param.detach().clone()

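        # Helper that swaps the layer's separate input/recurrent gates for a
        # single FusedRgLru: the gate weights and biases are concatenated along
        # the last axis and recurrent_param is copied over unchanged.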
        def fuse_rg_lru(recurrent_layer):
            fused_layer = tensorrt_llm.layers.FusedRgLru(lru_width=lru_width,
                                                         num_heads=num_heads,
                                                         dtype=dtype)
            fused_layer.gate.weight.value = np.concatenate([
                recurrent_layer.rg_lru.input_gate.weight.raw_value,
                recurrent_layer.rg_lru.recurrent_gate.weight.raw_value
            ],
                                                           axis=-1)
            fused_layer.gate.bias.value = np.concatenate([
                recurrent_layer.rg_lru.input_gate.bias.raw_value,
                recurrent_layer.rg_lru.recurrent_gate.bias.raw_value
            ],
                                                         axis=-1)
            fused_layer.recurrent_param.value = recurrent_layer.rg_lru.recurrent_param.raw_value
            recurrent_layer.rg_lru = fused_layer

        # construct trt network
        builder = tensorrt_llm.Builder()
        builder.strongly_typed = False  # Test needs to run in weakly typed mode
        net = builder.create_network()
        if use_plugin:
            net.plugin_config.mamba_conv1d_plugin = dtype
            net.plugin_config.gemm_plugin = dtype
        else:
            net.plugin_config.mamba_conv1d_plugin = None
            net.plugin_config.gemm_plugin = None
        if remove_padding:
            net.plugin_config.remove_input_padding = True
        else:
            net.plugin_config.remove_input_padding = False
        net.plugin_config.paged_state = False

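        # Runtime tensor definitions mirror the Mamba2 test above; the notable
        # difference is that the RG-LRU state is declared as float32 regardless
        # of the activation dtype.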
        with tensorrt_llm.net_guard(net):
            hidden_states_tensor = Tensor(
                name='hidden_states',
                shape=hidden_states.shape,
                dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            conv_state_tensor = Tensor(
                name='conv_state',
                shape=trt_conv_state_shape,
                dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            lru_state_tensor = Tensor(
                name='lru_state',
                shape=lru_state.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('float32'))
            host_request_types_tensor = Tensor(
                name='host_request_types',
                shape=host_request_types.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            last_token_ids_tensor = Tensor(
                name='last_token_ids',
                shape=last_token_ids.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_context_lengths_tensor = Tensor(
                name='host_context_lengths',
                shape=host_context_lengths.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            conv_indices_tensor = Tensor(
                name='conv_indices',
                shape=trt_conv_state_shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            recurrent_layer = tensorrt_llm.layers.Recurrent(width=width,
                                                            lru_width=lru_width,
                                                            d_conv=d_conv,
                                                            num_heads=num_heads,
                                                            dtype=dtype)
            recurrent_layer.rg_lru.recurrent_param.value = torch_to_numpy(
                recurrent_param.detach().cpu())
            recurrent_layer.linear_x.weight.value = torch_to_numpy(
                recurrent_torch.linear_x.weight.detach().cpu())
            recurrent_layer.linear_x.bias.value = torch_to_numpy(
                recurrent_torch.linear_x.bias.detach().cpu())
            recurrent_layer.linear_y.weight.value = torch_to_numpy(
                recurrent_torch.linear_y.weight.detach().cpu())
            recurrent_layer.y_bias.value = torch_to_numpy(
                recurrent_torch.y_bias.squeeze().detach().cpu())
            recurrent_layer.conv1d.weight.value = torch_to_numpy(
                recurrent_torch.conv1d.weight.detach().unsqueeze(3).cpu())
            recurrent_layer.conv1d.bias.value = torch_to_numpy(
                recurrent_torch.conv1d.bias.detach().cpu())
            recurrent_layer.rg_lru.input_gate.weight.value = torch_to_numpy(
                recurrent_torch.input_gate.w.detach().cpu())
            recurrent_layer.rg_lru.input_gate.bias.value = torch_to_numpy(
                recurrent_torch.input_gate.b.detach().cpu())
            recurrent_layer.rg_lru.recurrent_gate.weight.value = torch_to_numpy(
                recurrent_torch.recurrent_gate.w.detach().cpu())
            recurrent_layer.rg_lru.recurrent_gate.bias.value = torch_to_numpy(
                recurrent_torch.recurrent_gate.b.detach().cpu())
            recurrent_layer.linear_out.weight.value = torch_to_numpy(
                recurrent_torch.linear_out.weight.detach().cpu())
            recurrent_layer.linear_out.bias.value = torch_to_numpy(
                recurrent_torch.linear_out.bias.detach().cpu())

            if use_fused_rg_lru:
                fuse_rg_lru(recurrent_layer)

            outputs = recurrent_layer(
                hidden_states_tensor,
                conv_state_tensor,
                lru_state_tensor,
                host_request_types_tensor,
                last_token_ids_tensor,
                host_context_lengths=host_context_lengths_tensor,
                conv_indices=conv_indices_tensor)
            net._mark_output(outputs[0],
                             'output',
                             dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            net._mark_output(outputs[1],
                             'present_conv_state',
                             dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            net._mark_output(outputs[2],
                             'present_lru_state',
                             dtype=tensorrt_llm.str_dtype_to_trt('float32'))

        if use_plugin:
            trt_conv_state = conv_state.permute(0, 2, 1).contiguous()
        else:
            trt_conv_state = conv_state.clone().detach()
        trt_conv_indices = conv_indices.clone().detach()
        # trt run
        inputs = {
            'hidden_states': hidden_states,
            'conv_state': trt_conv_state,
            'lru_state': lru_state,
            'host_request_types': host_request_types,
            'last_token_ids': last_token_ids,
            'host_context_lengths': host_context_lengths,
            'conv_indices': trt_conv_indices,
        }
        outputs = {
            'output': output,
            'present_conv_state': present_conv_state,
            'present_lru_state': lru_state,
        }

        stream = torch.cuda.current_stream()
        builder_config = builder.create_builder_config(name='recurrent',
                                                       precision=dtype)
        engine = builder.build_engine(net, builder_config)
        session = tensorrt_llm.runtime.Session.from_serialized_engine(engine)
        session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)

        # pytorch run
        out_ref, conv_state_ref, lru_state_ref = recurrent_torch(
            hidden_states_ref, segment_pos, batch_size, remove_padding,
            last_token_ids, conv_state_ref, lru_state_ref, conv_indices_ref)

        dtype_atol = {"float16": 1e-2, "float32": 5e-3, "bfloat16": 5e-2}

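        # Same masking scheme as the Mamba2 test: in the padded context case
        # only positions before each sequence's length are compared; everything
        # past them is zeroed in both outputs.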
        # get mask
        if not remove_padding and req_type == 'context':
            out_mask = torch.zeros(batch_size, seq_len, device=device)
            for i in range(batch_size):
                for j in range(last_token_ids[i]):
                    out_mask[i, j] = 1
            out_mask = out_mask.unsqueeze(2).expand(
                [batch_size, seq_len, width])
        else:
            out_mask = torch.ones(size=hidden_states_shape, device=device)

        # compare results
        outputs['output'][out_mask == 0] = 0
        output_trt_llm = outputs['output'].detach().to(torch.float32).cpu()
        output_torch = (out_ref * out_mask).detach().to(torch.float32).cpu()
        lru_state_trt_llm = outputs['present_lru_state'].detach().to(
            torch.float32).cpu()
        lru_state_torch = lru_state_ref.detach().to(torch.float32).cpu()
        conv_state_ref = conv_state_ref[:, :, 1:].detach().to(
            torch.float32).cpu().numpy()
        conv_state_trt_llm = outputs['present_conv_state']
        if use_plugin:
            conv_state_trt_llm = conv_state_trt_llm.permute(0, 2, 1)
        conv_state_trt_llm = conv_state_trt_llm.to(torch.float32).cpu().numpy()

        atol = dtype_atol[dtype]
        rtol = 1e-7

        np.testing.assert_allclose(output_torch.numpy(),
                                   output_trt_llm.numpy(),
                                   atol=atol,
                                   rtol=rtol)
        np.testing.assert_allclose(lru_state_torch.numpy(),
                                   lru_state_trt_llm.numpy(),
                                   atol=atol,
                                   rtol=rtol)
        np.testing.assert_allclose(conv_state_ref,
                                   conv_state_trt_llm,
                                   atol=atol,
                                   rtol=rtol)

    def test_gemma3_local_attention_rope_scaling(self):
        """
        Test that local attention layers in Gemma3 do NOT apply rope scaling,
        even when the config has rope_scaling defined.

        This is important for Gemma3, which uses different RoPE parameters for
        local (sliding window) attention and global attention layers. The fix
        ensures that local attention layers get scale=1.0 and scale_type=none,
        while global layers get the configured scaling.
        """
        from tensorrt_llm.functional import PositionEmbeddingType
        from tensorrt_llm.layers.attention import Attention

        # Create a mock config similar to Gemma3 27B with rope_scaling
        class MockGemma3Config:
            hidden_size = 5376
            num_attention_heads = 32
            head_size = 128
            max_position_embeddings = 32768
            position_embedding_type = PositionEmbeddingType.rope_gpt_neox
            # Use small rotary base values to avoid numerical instability in tests.
            # Large bases (e.g. 1000000) get exponentiated, causing potential flakiness
            # when comparing floating point results.
            rotary_base = 100.0
            rotary_scaling = {"factor": 8.0, "rope_type": "linear"}
            rotary_pct = 1.0
            # Local attention uses a different base frequency
            rope_local_base_freq = 10.0

        # Create a mock model class to receive registered parameters
        class MockModelCls:
            position_embedding_type = PositionEmbeddingType.rope_gpt_neox

            @classmethod
            def register_parameter(cls, name, param):
                setattr(cls, name, param)

        config = MockGemma3Config()

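        # create_attention_const_params registers the rotary tables through
        # register_parameter, which the mock above simply stores as class
        # attributes so the test can read them back via .raw_value.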
        # Call the method that creates attention const params
        Attention.create_attention_const_params(MockModelCls, config)

        # Verify that global rope parameters are registered
        self.assertTrue(hasattr(MockModelCls, 'embed_positions'),
                        "Global embed_positions should be registered")
        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq'),
                        "Global rotary_inv_freq should be registered")
        self.assertTrue(
            hasattr(MockModelCls, 'embed_positions_for_gpt_attention'),
            "Global embed_positions_for_gpt_attention should be registered")

        # Verify that local rope parameters are registered (since rope_local_base_freq is set)
        self.assertTrue(hasattr(MockModelCls, 'embed_positions_local'),
                        "Local embed_positions should be registered")
        self.assertTrue(hasattr(MockModelCls, 'rotary_inv_freq_local'),
                        "Local rotary_inv_freq should be registered")
        self.assertTrue(
            hasattr(MockModelCls, 'embed_positions_for_gpt_attention_local'),
            "Local embed_positions_for_gpt_attention should be registered")

        # Get the parameter values
        global_inv_freq = MockModelCls.rotary_inv_freq.raw_value
        local_inv_freq = MockModelCls.rotary_inv_freq_local.raw_value
        global_cos_sin = MockModelCls.embed_positions_for_gpt_attention.raw_value
        local_cos_sin = MockModelCls.embed_positions_for_gpt_attention_local.raw_value

        # The global and local inv_freq should be different because:
        # 1. Global uses rope_scaling with factor=8.0 (linear scaling applies 1/8 to inv_freq)
        # 2. Local uses scale=1.0 (no scaling)
        self.assertFalse(
            np.allclose(global_inv_freq, local_inv_freq),
            "Global and local rotary_inv_freq should be different "
            "(global has scaling, local does not)")

        # The cos/sin embeddings should also be different
        self.assertFalse(
            np.allclose(global_cos_sin, local_cos_sin),
            "Global and local embed_positions_for_gpt_attention should be different "
            "(global has scaling, local does not)")

        # Additional verification: check that local inv_freq matches the unscaled calculation
        # For local attention with scale=1.0 and base=10:
        # inv_freq = 1.0 / (10 ** (arange(0, dim, 2) / dim))
        dim = config.head_size  # rotary_embedding_dim = head_size * rotary_pct = 128
        expected_local_inv_freq = 1.0 / (config.rope_local_base_freq
                                         **(np.arange(0, dim, 2) / dim))

        np.testing.assert_allclose(
            local_inv_freq,
            expected_local_inv_freq,
            rtol=1e-5,
            err_msg="Local rotary_inv_freq should be computed WITHOUT scaling")

        # For global attention with linear scaling (factor=8.0):
        # scale = 1.0 / 8.0 = 0.125
        # inv_freq = 0.125 / (100 ** (arange(0, dim, 2) / dim))
        expected_global_inv_freq = (1.0 / 8.0) / (config.rotary_base**
                                                  (np.arange(0, dim, 2) / dim))

        np.testing.assert_allclose(
            global_inv_freq,
            expected_global_inv_freq,
            rtol=1e-5,
            err_msg=
            "Global rotary_inv_freq should be computed WITH linear scaling")


if __name__ == '__main__':
    unittest.main()