# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import sys
import unittest

import _utils
import numpy as np
import tensorrt as trt
import torch
from parameterized import parameterized
from transformers import GPT2Config

import tensorrt_llm
from tensorrt_llm import Tensor
from tensorrt_llm._utils import torch_to_numpy
from tensorrt_llm.quantization import QuantMode

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import (create_session, run_session, skip_bf16_pre_ampere,
                        skip_pre_ampere, unittest_name_func)


class GPT2AttentionSmoothQuant(torch.nn.Module):
    """ initially copied from transformers.models.gpt2.modeling_gpt2
    with modifications to run "smoothquant" GEMMs (i.e. i8xi8->i32->fp16)
    """

    def __init__(self, config):
        super().__init__()

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(
                torch.ones((max_positions, max_positions),
                           dtype=torch.uint8)).view(1, 1, max_positions,
                                                    max_positions),
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights

        # we can't register linear layer with pytorch in int32. Use functional
        # and define registered buffers instead
        self.register_buffer("c_attn_weight",
                             torch.empty((3 * self.embed_dim, self.embed_dim)))
        self.register_buffer("c_attn_bias", torch.empty(3 * self.embed_dim))
        self.register_buffer("c_proj_weight",
                             torch.empty((self.embed_dim, self.embed_dim)))
        self.register_buffer("c_proj_bias", torch.empty(self.embed_dim))

        self.attn_dropout = torch.nn.Dropout(config.attn_pdrop)
        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / (float(value.size(-1))**0.5)

        # if only "normal" attention layer implements causal mask
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length -
                                query_length:key_length, :key_length].bool()
        mask_value = torch.finfo(attn_weights.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
            attn_weights.device)
        attn_weights = torch.where(causal_mask,
                                   attn_weights.to(attn_weights.dtype),
                                   mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = torch.nn.Softmax(dim=-1)(attn_weights)
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(*new_shape)
        return tensor.permute(0, 2, 1,
                              3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size, )
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states,
        dtype,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
        quant_mode=QuantMode(0),
        c_attn_dyn_scaling_factor=None,
    ):
        if not quant_mode.has_act_and_weight_quant():
            raise ValueError("quant_mode has to have some quantization")

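        # QKV projection as an int8 GEMM: the accumulator is dequantized with
        # the activation scale (scale_attn_out) and per-channel weight scale
        # (scale_attn_w), then the bias is added in `dtype`.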
        qkv = _utils.gt_matmul_smooth_quant(hidden_states, self.c_attn_weight,
                                            self.scale_attn_out,
                                            self.scale_attn_w, dtype)
        qkv = (qkv + self.c_attn_bias).to(
            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))
        query, key, value = qkv.split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        attn_output, attn_weights = self._attn(query, key, value,
                                               attention_mask, head_mask)

        def to_i8(x):
            return x.round().clip(-128, 127).to(dtype=torch.int8)

        attn_output = self._merge_heads(attn_output, self.num_heads,
                                        self.head_dim)

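        # Re-quantize the attention output to int8 for the projection GEMM:
        # either with the static per-tensor scale (scale_proj_in) or
        # dynamically per token.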
        scales = self.scale_proj_out
        if quant_mode.has_act_static_scaling():
            attn_output = to_i8(attn_output * self.scale_proj_in.cuda())
        else:
            attn_output, scales = _utils.gt_quantize_per_token(attn_output)

        attn_output = _utils.gt_matmul_smooth_quant(attn_output,
                                                    self.c_proj_weight, scales,
                                                    self.scale_proj_w, dtype)
        attn_output = (attn_output + self.c_proj_bias).to(
            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights, )

        return outputs  # a, present, (attentions)


class TestSmoothQuant(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('error')

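    # Each case is (dtype, per_token_scaling, per_channel_scaling, bias, linear_cls).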
    @parameterized.expand(
        [('float16', False, False, False,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('float16', False, True, False,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('float16', True, False, False,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('float16', True, True, False,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('bfloat16', False, False, False,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('bfloat16', False, True, False,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('bfloat16', True, False, False,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('bfloat16', True, True, False,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('float32', True, True, False,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('int32', True, True, False,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('float32', False, False, True,
          tensorrt_llm.quantization.layers.SmoothQuantLinear),
         ('float32', False, False, True,
          tensorrt_llm.quantization.layers.SmoothQuantRowLinear)],
        name_func=unittest_name_func)
    @skip_pre_ampere  # INT8 Tensor Cores are not supported pre-Ampere
    def test_linear_smooth_quant(self, dtype, per_token_scaling,
                                 per_channel_scaling, bias, linear_cls):
        # test data
        d_h = 32
        ffn_h = 64
        test_shape = [2, 3, 5, d_h]

        # Init operands for multiplication in int8
        x_data = torch.randint(-128,
                               128,
                               test_shape,
                               dtype=torch.int8,
                               device="cuda")
        fc1 = torch.randint(-128,
                            128, (ffn_h, d_h),
                            dtype=torch.int8,
                            device="cuda")

        bias_data = None
        if bias:
            bias_data = torch.randint(
                -5,
                5, (ffn_h, ),
                dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
                device="cuda") * 0.1

        m = test_shape[0] * test_shape[1] * test_shape[2]
        c_1 = ffn_h

        quant_mode = QuantMode.from_description(True, True, per_token_scaling,
                                                per_channel_scaling)

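        # Random dequant scales in [0.01, 0.09]: scale_a covers activations
        # ((m, 1) when per-token), scale_b covers weights ((1, n) when
        # per-channel); both collapse to scalars otherwise.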
        def init_scales(n):
            scale_a_shape = (m, 1) if per_token_scaling else (1, 1)
            scale_a = torch.ones(
                scale_a_shape, dtype=torch.float32, device="cuda") * 1e-2
            scale_a *= torch.randint(1,
                                     10,
                                     scale_a_shape,
                                     dtype=torch.float32,
                                     device="cuda")
            scale_b_shape = (1, n) if per_channel_scaling else (1, 1)
            scale_b = torch.ones(
                scale_b_shape, dtype=torch.float32, device="cuda") * 1e-2
            scale_b *= torch.randint(1,
                                     10,
                                     scale_b_shape,
                                     dtype=torch.float32,
                                     device="cuda")
            return scale_a, scale_b

        scale_fc1_out, scale_fc1_w = init_scales(c_1)

        # construct trt network
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        # Allow SQ plugin of dtype type
        network.plugin_config.smooth_quant_gemm_plugin = dtype
        with tensorrt_llm.net_guard(network):

            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt('int8'))

            args = {}
            if linear_cls == tensorrt_llm.quantization.layers.SmoothQuantLinear:
                args['gather_output'] = False
            gm = linear_cls(d_h,
                            ffn_h,
                            bias=bias,
                            quant_mode=quant_mode,
                            **args)

            # TensorRT-LLM's Linear uses Parameter class which has a 'value' setter
            gm.weight.value = fc1.cpu().numpy()
            gm.per_channel_scale.value = scale_fc1_w.cpu().numpy()
            if bias:
                gm.bias.value = bias_data.cpu().numpy()
            # Set activation scaling factors if needed
            if quant_mode.has_act_static_scaling():
                gm.act_scale.value = scale_fc1_out.cpu().numpy()

            input = x
            # If we have dynamic scaling, Linear expects Tuple input:
            # (quantized tensor, scales per token)
            if quant_mode.has_per_token_dynamic_scaling():
                scale_dynamic = Tensor(
                    name='scale_dynamic',
                    shape=scale_fc1_out.shape,
                    dtype=tensorrt_llm.str_dtype_to_trt('float32'))

                input = (x, scale_dynamic)

            output = gm.forward(input)
            output.mark_output('output')

        # trt run
        session = create_session(builder,
                                 network,
                                 precision=dtype,
                                 int8=True,
                                 memory_pool_limit=33554432)
        inputs = {
            'x': x_data,
        }
        if quant_mode.has_per_token_dynamic_scaling():
            inputs['scale_dynamic'] = scale_fc1_out

        outputs = run_session(session, inputs)

        # pytorch run
        with torch.no_grad():
            ref = _utils.gt_matmul_smooth_quant(x_data, fc1, scale_fc1_out,
                                                scale_fc1_w, dtype, bias_data)

        # compare diff
        torch.testing.assert_close(ref, outputs['output'])

    @parameterized.expand(
        [(tensorrt_llm.quantization.layers.SmoothQuantLinear),
         (tensorrt_llm.quantization.layers.SmoothQuantRowLinear)],
        name_func=unittest_name_func)
    def test_linear_smooth_quant_no_quant(self, linear_cls):
        # Weight only quant for SmoothQuant
        quant_mode = QuantMode.from_description(quantize_weights=True,
                                                quantize_activations=False,
                                                per_token=False,
                                                per_channel=False)

        args = {}
        if linear_cls == tensorrt_llm.quantization.layers.SmoothQuantLinear:
            args['gather_output'] = False

        # Create builder
        builder = tensorrt_llm.Builder()
        # Create empty network
        network = builder.create_network()
        with tensorrt_llm.net_guard(network):
            # Get output tensor for SQ Linear
            with self.assertRaisesRegex(
                    ValueError,
                    r"SmoothQuant Linear has to have act\+weight quantization mode set"
            ):
                linear_cls(32, 64, bias=False, quant_mode=quant_mode, **args)

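    # Each case is (dtype, per_token_scaling, per_channel_scaling, hidden_act);
    # 'elu' checks that unsupported activations are rejected.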
    @parameterized.expand([('float16', False, False, 'gelu'),
                           ('float16', False, True, 'gelu'),
                           ('float16', True, False, 'gelu'),
                           ('float16', True, True, 'gelu'),
                           ('bfloat16', False, False, 'gelu'),
                           ('bfloat16', False, True, 'gelu'),
                           ('bfloat16', True, False, 'gelu'),
                           ('bfloat16', True, True, 'gelu'),
                           ('float32', True, True, 'gelu'),
                           ('float32', True, True, 'elu')],
                          name_func=unittest_name_func)
    @skip_pre_ampere  # INT8 Tensor Cores are not supported pre-Ampere
    def test_mlp_smooth_quant(self, dtype, per_token_scaling,
                              per_channel_scaling, hidden_act):
        # test data
        d_h = 16
        ffn_h = 32
        test_shape = [2, 3, 5, d_h]

        torch.manual_seed(42)

        # Init operands for multiplication in int8
        x_data = torch.randint(-8,
                               8,
                               test_shape,
                               dtype=torch.int8,
                               device="cuda")
        fc1 = torch.randint(-16,
                            16, (ffn_h, d_h),
                            dtype=torch.int8,
                            device="cuda")
        fc2 = torch.randint(-16,
                            16, (d_h, ffn_h),
                            dtype=torch.int8,
                            device="cuda")

        m = test_shape[0] * test_shape[1] * test_shape[2]
        c_1 = ffn_h
        c_2 = d_h

        quant_mode = QuantMode.from_description(True, True, per_token_scaling,
                                                per_channel_scaling)

        def init_scales(n):
            scale_a_shape = (m, 1) if per_token_scaling else (1, 1)
            scale_a = torch.ones(
                scale_a_shape, dtype=torch.float32, device="cuda") * 1e-2
            scale_a *= torch.randint(1,
                                     10,
                                     scale_a_shape,
                                     dtype=torch.float32,
                                     device="cuda")
            scale_b_shape = (1, n) if per_channel_scaling else (1, 1)
            scale_b = torch.ones(
                scale_b_shape, dtype=torch.float32, device="cuda") * 1e-2
            scale_b *= torch.randint(1,
                                     10,
                                     scale_b_shape,
                                     dtype=torch.float32,
                                     device="cuda")
            return scale_a, scale_b

        scale_fc1_out, scale_fc1_w = init_scales(c_1)
        scale_fc2_out, scale_fc2_w = init_scales(c_2)
        scale_fc2_in = torch.randint(
            3, 7, (1, ), dtype=torch.float32, device="cuda") * 0.1

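        # scale_fc2_in is the static per-tensor scale used to re-quantize the
        # GELU output before the second GEMM (static scaling mode only).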
        # construct trt network
        builder = tensorrt_llm.Builder()
        builder.strongly_typed = False  # Test needs to run in weakly typed mode
        network = builder.create_network()
        # Allow SQ plugin of dtype type
        network.plugin_config.smooth_quant_gemm_plugin = dtype
        with tensorrt_llm.net_guard(network):

            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt('int8'))

            if hidden_act == 'elu':
                context = self.assertRaisesRegex(
                    ValueError, "unsupported activation function: *")
            else:
                context = contextlib.nullcontext()

            with context:
                gm = tensorrt_llm.quantization.layers.SmoothQuantMLP(
                    d_h,
                    ffn_h,
                    hidden_act=hidden_act,
                    bias=False,
                    quant_mode=quant_mode)

            if hidden_act != 'gelu':
                return

            # TensorRT-LLM's MLP uses Parameter class which has a 'value' setter
            gm.fc.weight.value = fc1.cpu().numpy()
            gm.fc.per_channel_scale.value = scale_fc1_w.cpu().numpy()
            gm.proj.weight.value = fc2.cpu().numpy()
            gm.proj.per_channel_scale.value = scale_fc2_w.cpu().numpy()
            gm.proj.smoother.value = np.ones([1, fc2.shape[1]],
                                             dtype=np.float32)

            # Set activation scaling factors if needed
            if quant_mode.has_act_static_scaling():
                gm.quantization_scaling_factor.value = scale_fc2_in.cpu().numpy(
                )
                gm.fc.act_scale.value = scale_fc1_out.cpu().numpy()
                gm.proj.act_scale.value = scale_fc2_out.cpu().numpy()

            input = x
            if quant_mode.has_per_token_dynamic_scaling():
                scale_dynamic = Tensor(
                    name='scale_dynamic',
                    shape=scale_fc1_out.shape,
                    dtype=tensorrt_llm.str_dtype_to_trt('float32'))

                input = (x, scale_dynamic)

            output = gm.forward(input)
            output.mark_output("output")

        # trt run
        session = create_session(builder,
                                 network,
                                 precision=dtype,
                                 int8=True,
                                 memory_pool_limit=33554432)
        inputs = {
            'x': x_data,
        }
        if quant_mode.has_per_token_dynamic_scaling():
            inputs['scale_dynamic'] = scale_fc1_out

        outputs = run_session(session, inputs)

        # pytorch run
        with torch.no_grad():
            gelu = torch.nn.GELU()
            # FC 1
            hidden = _utils.gt_matmul_smooth_quant(x_data, fc1, scale_fc1_out,
                                                   scale_fc1_w, dtype)
            # ACT
            hidden = gelu(hidden)

            # Dynamic/static quantization
            scale_act = scale_fc2_out
            if quant_mode.has_per_token_dynamic_scaling():
                hidden, scale_act = _utils.gt_quantize_per_token(hidden)
            else:
                hidden = (hidden * scale_fc2_in).round().clip(
                    -128, 127).to(dtype=torch.int8)

            # FC 2
            ref = _utils.gt_matmul_smooth_quant(hidden, fc2, scale_act,
                                                scale_fc2_w, dtype)

        # compare diff
        torch.testing.assert_close(ref, outputs['output'], atol=6.25e-2, rtol=0)

    @parameterized.expand([('float16', True, True), ('float16', True, False),
                           ('bfloat16', True, True)],
                          name_func=unittest_name_func)
    def test_smooth_quant_layer_norm_layer(self, dtype, per_token_scaling,
                                           elementwise_affine):
        # Skip tests that are not supported in pre-ampere architecture
        skip_bf16_pre_ampere(dtype)

        torch.manual_seed(1997)
        # test data
        hidden_size = 1024
        x_data = torch.randn(
            (8, 128, hidden_size),
            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
            device="cuda")
        eps = 1e-5

        m = torch.nn.LayerNorm(
            hidden_size,
            eps=eps,
            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
            elementwise_affine=elementwise_affine,
            device="cuda")

        # Scale to int
        scale_data = torch.randint(2,
                                   32, (1, ),
                                   dtype=torch.float32,
                                   device="cuda")
        scale_to_int_data = torch.ones((1, ),
                                       dtype=torch.float32,
                                       device="cuda")

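        # SmoothQuantLayerNorm fuses the int8 quantization into the layernorm;
        # with per-token scaling it returns (int8 output, per-token scales),
        # as unpacked below.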
        quant_mode = QuantMode.from_description(quantize_weights=True,
                                                quantize_activations=True,
                                                per_token=per_token_scaling,
                                                per_channel=False)

        # construct trt network
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        network.plugin_config.layernorm_quantization_plugin = dtype
        with tensorrt_llm.net_guard(network):
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))

            ln = tensorrt_llm.quantization.layers.SmoothQuantLayerNorm(
                hidden_size,
                quant_mode=quant_mode,
                elementwise_affine=elementwise_affine,
                dtype=dtype)

            ln.scale_to_int.value = scale_to_int_data.detach().cpu().numpy()

            if elementwise_affine:
                gamma_data = m.weight.detach().cpu()
                beta_data = m.bias.detach().cpu()
                ln.weight.value = torch_to_numpy(gamma_data)
                ln.bias.value = torch_to_numpy(beta_data)

            output = ln.forward(x)

            if per_token_scaling:
                output, dynamic_scales = output
                dynamic_scales.mark_output('dynamic_scales', trt.float32)
            output.mark_output('output', trt.int8)

        # trt run
        session = create_session(builder, network, precision=dtype, int8=True)
        inputs = {
            'x': x_data,
        }

        outputs = run_session(session, inputs)

        def cast_to_int8_with_sat(tensor):
            return tensor.round().clip(-128, 127).to(dtype=torch.int8)

        # pytorch run
        with torch.no_grad():
            ref = m(x_data).to(dtype=torch.float32)
            if per_token_scaling:
                abs_max_f, _ = ref.abs().max(dim=-1, keepdim=True)
                dynamic_scale = abs_max_f / 127.0
                ref_quantized = cast_to_int8_with_sat(ref * (127.0 / abs_max_f))
            else:
                ref_quantized = cast_to_int8_with_sat(ref * scale_data)

        # compare diff of quantized output
        # Set absolute tolerance to 1 to mitigate some rounding error
        torch.testing.assert_close(ref_quantized,
                                   outputs['output'],
                                   atol=1,
                                   rtol=0)

        # compare diff of dynamic activation scales
        if per_token_scaling:
            torch.testing.assert_close(dynamic_scale,
                                       outputs['dynamic_scales'],
                                       atol=1e-2,
                                       rtol=1e-2)

    @parameterized.expand(
        [('float16', 1, False,
          tensorrt_llm.quantization.layers.WeightOnlyQuantLinear),
         ('float16', 2, False,
          tensorrt_llm.quantization.layers.WeightOnlyQuantLinear),
         ('float16', 1, True,
          tensorrt_llm.quantization.layers.WeightOnlyQuantLinear),
         ('float16', 1, True,
          tensorrt_llm.quantization.layers.WeightOnlyQuantRowLinear)],
        name_func=unittest_name_func)
    @skip_pre_ampere  # WOQ contains a bug in pre-Ampere
    def test_linear_weight_only_linear(self, dtype, wTypeId, bias, linear_cls):
        # test data
        m = 1
        n = 1024
        k = 4096

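        # wTypeId selects the packed weight type: 1 -> int8, 2 -> int4
        # (matching use_int4_weights=(wTypeId == 2) below).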
        # Init operands for multiplication in int32
        mat1 = _utils.woq_gen_weights(m, k, dtype) * 200.0
        weight = _utils.woq_gen_weights(k, n, dtype)

        ref_torch_weights, processed_torch_weights, torch_weight_scales = _utils.woq_conversion(
            weight, wTypeId)
        if wTypeId == 2:
            # Weights must be a CPU Tensor
            ref_torch_weights = torch.ops.trtllm.unpack_int4_packed_tensor_to_int8(
                ref_torch_weights.cpu())

        bias_data = None
        if bias:
            bias_data = torch.randint(
                -5,
                5, (n, ),
                dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
                device="cuda") * 0.1

        quant_mode = QuantMode.from_description(quantize_weights=True,
                                                quantize_activations=False,
                                                per_token=False,
                                                per_channel=False,
                                                use_int4_weights=(wTypeId == 2))

        # construct trt network
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        network.plugin_config.weight_only_quant_matmul_plugin = dtype
        with tensorrt_llm.net_guard(network):

            x = Tensor(name='x',
                       shape=mat1.shape,
                       dtype=tensorrt_llm._utils.str_dtype_to_trt(dtype))

            args = {}
            if linear_cls == tensorrt_llm.quantization.layers.WeightOnlyQuantLinear:
                args['gather_output'] = False
            gm = linear_cls(k, n, bias=bias, quant_mode=quant_mode, **args)

            # TensorRT-LLM's Linear uses Parameter class which has a 'value' setter
            gm.weight.value = processed_torch_weights.cpu().numpy()
            gm.per_channel_scale.value = torch_weight_scales.cpu().numpy()
            if bias:
                gm.bias.value = bias_data.cpu().numpy()

            input = x

            output = gm.forward(input)
            output.mark_output('output')

        # trt run
        session = create_session(builder,
                                 network,
                                 precision=dtype,
                                 int8=True,
                                 memory_pool_limit=33554432)
        inputs = {
            'x': mat1,
        }

        outputs = run_session(session, inputs)

        # pytorch run
        with torch.no_grad():
            ref = _utils.woq_gt_matmul(m, mat1, ref_torch_weights.cuda(),
                                       torch_weight_scales.cuda(), dtype,
                                       bias_data)

        # compare diff
        _utils.woq_assert_near_eq(ref, outputs['output'], wTypeId)

    @parameterized.expand([('float16', QuantMode.PER_CHANNEL),
                           ('float16', QuantMode.PER_TOKEN),
                           ('float16', QuantMode.PER_GROUP),
                           ('bfloat16', QuantMode.PER_CHANNEL),
                           ('bfloat16', QuantMode.PER_TOKEN),
                           ('bfloat16', QuantMode.PER_GROUP)],
                          name_func=unittest_name_func)
    @unittest.skip("Attention contains a bug and will be resolved in later MRs")
    def test_gpt_attention_smoothquant(self,
                                       dtype="float16",
                                       quant_mode=QuantMode.from_description(
                                           True, True, False, False)):

        def _construct_execution():
            builder = tensorrt_llm.Builder()
            network = builder.create_network()
            network.plugin_config.smooth_quant_gemm_plugin = dtype
            network.plugin_config.gpt_attention_plugin = dtype
            with tensorrt_llm.net_guard(network):
                hidden_states_tensor = Tensor(
                    name='hidden_states',
                    shape=tuple(input.shape),
                    dtype=tensorrt_llm.str_dtype_to_trt('int8'))

                input_tensor = hidden_states_tensor
                if quant_mode.has_per_token_dynamic_scaling():
                    scale_dynamic_tensor = Tensor(
                        name='scale_dynamic',
                        shape=tuple(scale_attn_out.shape),
                        dtype=tensorrt_llm.str_dtype_to_trt('float32'))

                    input_tensor = (hidden_states_tensor, scale_dynamic_tensor)

                past_key_value = None
                if use_past_key_value or use_gpt_attention_plugin:
                    past_key_tensor = Tensor(
                        name='past_key',
                        shape=tuple(present_key.shape),
                        dtype=tensorrt_llm.str_dtype_to_trt(dtype))
                    past_value_tensor = Tensor(
                        name='past_value',
                        shape=tuple(present_value.shape),
                        dtype=tensorrt_llm.str_dtype_to_trt(dtype))

                    past_key_value = (past_key_tensor, past_value_tensor)

                sequence_length_tensor = None
                past_key_value_length_tensor = None
                input_lengths_tensor = None
                cache_indirection_tensor = None
                if use_gpt_attention_plugin:
                    sequence_length_tensor = Tensor(name='sequence_length',
                                                    dtype=trt.int32,
                                                    shape=tuple(
                                                        sequence_length.shape))
                    past_key_value_length_tensor = Tensor(
                        name='past_key_value_length',
                        dtype=trt.int32,
                        shape=tuple(past_key_value_length.shape))

                    input_lengths_tensor = Tensor(name='input_lengths',
                                                  dtype=trt.int32,
                                                  shape=tuple(
                                                      input_lengths.shape))

                    cache_indirection_tensor = Tensor(
                        name='cache_indirection',
                        dtype=trt.int32,
                        shape=tuple(cache_indirection.shape))

                attention = tensorrt_llm_gpt

                attention.qkv.weight.value = weight_qkv.cpu().numpy()
                attention.qkv.bias.value = bias_qkv.cpu().numpy()
                attention.qkv.per_channel_scale.value = scale_attn_w.cpu(
                ).numpy()

                attention.dense.weight.value = weight_proj.cpu().numpy()
                attention.dense.bias.value = bias_proj.cpu().numpy()
                attention.dense.per_channel_scale.value = scale_proj_w.cpu(
                ).numpy()
                attention.dense.smoother.value = np.ones(
                    [1, weight_proj.shape[1] * 4], dtype=np.float32)

                # Set activation scaling factors if needed
                if quant_mode.has_act_static_scaling():
                    attention.quantization_scaling_factor.value = scale_proj_in.cpu(
                    ).numpy()
                    attention.qkv.act_scale.value = scale_attn_out.cpu().numpy()
                    attention.dense.act_scale.value = scale_proj_out.cpu(
                    ).numpy()

                outputs = attention(
                    input_tensor,
                    attention_mask=None,
                    past_key_value=past_key_value,
                    sequence_length=sequence_length_tensor,
                    past_key_value_length=past_key_value_length_tensor,
                    use_cache=True,
                    input_lengths=input_lengths_tensor,
                    cache_indirection=cache_indirection_tensor)

                outputs[0].mark_output('output', dtype)
                outputs[1][0].mark_output('present_key', dtype)
                outputs[1][1].mark_output('present_value', dtype)

            # trt build engine
            session = create_session(builder,
                                     network,
                                     precision=dtype,
                                     int8=True,
                                     memory_pool_limit=48 * (2**30))
            inputs = {
                'hidden_states': input,
                'cache_indirection': cache_indirection
            }

            if use_past_key_value or use_gpt_attention_plugin:
                inputs['past_key'] = present_key
                inputs['past_value'] = present_value

            if use_gpt_attention_plugin:
                inputs['sequence_length'] = sequence_length
                inputs['past_key_value_length'] = past_key_value_length
                inputs['input_lengths'] = input_lengths

            if quant_mode.has_per_token_dynamic_scaling():
                inputs['scale_dynamic'] = scale_attn_out

            outputs = run_session(session, inputs)
            return outputs

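        # Fixed shapes for the engine; _construct_execution (above) builds and
        # runs a fresh TRT session for the context step and for each
        # generation step. weight_proj is an identity matrix so the dense
        # projection is easy to track in the reference.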
        batch_size = 4
        in_len = 128
        out_len = 8
        max_seq_len = 148
        hidden_size = 1024
        num_heads = 16
        head_size = hidden_size // num_heads
        shape_dict = {
            'weight': (hidden_size * 3, hidden_size),
            'bias': (hidden_size * 3, ),
            'past_key': (batch_size, num_heads, max_seq_len, head_size),
            'past_value': (batch_size, num_heads, max_seq_len, head_size),
            'present_key': (batch_size, num_heads, max_seq_len, head_size),
            'present_value': (batch_size, num_heads, max_seq_len, head_size),
            'sequence_length': (batch_size, ),
            'input_lengths': (batch_size, ),
        }

        weight_qkv = torch.randint(-10,
                                   10,
                                   shape_dict['weight'],
                                   dtype=torch.int8)
        bias_qkv = torch.randn(
            shape_dict['bias'],
            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
            device='cuda') * 1e-1

        weight_proj = torch.eye(hidden_size, dtype=torch.int8)
        bias_proj = torch.zeros(
            (hidden_size, ),
            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
            device='cuda')

        input_lengths = torch.ones(
            (batch_size, ), dtype=torch.int32, device='cuda') * in_len
        cache_indirection = torch.full((
            batch_size,
            1,
            max_seq_len,
        ),
                                       0,
                                       dtype=torch.int32,
                                       device='cuda')

        present_key = torch.zeros(
            shape_dict['present_key'],
            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
            device='cuda')
        present_value = torch.zeros(
            shape_dict['present_value'],
            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
            device='cuda')
        torch_present = None

        per_token_scaling = quant_mode.has_per_token_dynamic_scaling()
        per_channel_scaling = quant_mode.has_per_channel_scaling()

        def init_scales(m, n, token_scaling, channel_scaling):
            scale_a_shape = (m, 1) if token_scaling else (1, 1)
            scale_a = torch.ones(scale_a_shape, dtype=torch.float32) * 1e-2
            scale_a *= torch.randint(1, 10, scale_a_shape, dtype=torch.float32)
            scale_b_shape = (1, n) if channel_scaling else (1, 1)
            scale_b = torch.ones(scale_b_shape, dtype=torch.float32) * 1e-2
            scale_b *= torch.randint(1, 10, scale_b_shape, dtype=torch.float32)
            return scale_a, scale_b

        # We always do per channel scaling for QKV
        scale_attn_out, scale_attn_w = init_scales(batch_size * in_len,
                                                   3 * hidden_size,
                                                   per_token_scaling, True)
        scale_proj_out, scale_proj_w = init_scales(batch_size * in_len,
                                                   hidden_size,
                                                   per_token_scaling,
                                                   per_channel_scaling)
        scale_proj_in = torch.randint(3, 7, (1, ), dtype=torch.float32) * 0.1

        # instantiate pytorch equivalent of attention SQ
        configuration = GPT2Config(
            hidden_size=hidden_size,
            n_layer=1,
            n_head=num_heads,
            vocab_size=51200,
            use_cache=True,
            resid_pdrop=0,
            embd_pdrop=0,
            attn_pdrop=0,
            hidden_act='gelu',
            torch_dtype=dtype,
        )
        n_positions = configuration.n_positions

        gt_attention = GPT2AttentionSmoothQuant(configuration).cuda().eval()
        gt_attention.c_attn_weight = torch.nn.parameter.Parameter(
            data=weight_qkv.clone().detach(), requires_grad=False)
        gt_attention.c_attn_bias = torch.nn.parameter.Parameter(
            data=bias_qkv.clone().detach(), requires_grad=False)
        gt_attention.c_proj_weight = torch.nn.parameter.Parameter(
            data=weight_proj, requires_grad=False)
        gt_attention.c_proj_bias = torch.nn.parameter.Parameter(
            data=bias_proj, requires_grad=False)
        gt_attention.scale_attn_out, gt_attention.scale_proj_out = scale_attn_out, scale_proj_out
        gt_attention.scale_attn_w, gt_attention.scale_proj_w = scale_attn_w, scale_proj_w
        gt_attention.scale_proj_in = scale_proj_in

        # instantiate TensorRT-LLM's SmoothQuant attention module
        tensorrt_llm_gpt = tensorrt_llm.quantization.layers.SmoothQuantAttention(
            layer_idx=0,
            hidden_size=hidden_size,
            num_attention_heads=num_heads,
            max_position_embeddings=n_positions,
            num_layers=1,
            attention_mask_type=tensorrt_llm.layers.AttentionMaskType.causal,
            dtype=dtype,
            quant_mode=quant_mode)

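        # Step 0 exercises the context phase over the full prompt; later steps
        # run single-token generation against the KV cache returned by the
        # previous step.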
        for step in range(out_len):
            sequence_length = torch.ones(
                (batch_size, ), dtype=torch.int32,
                device='cuda') * (in_len + step)
            if step == 0:
                # Context stage
                shape_dict['hidden_states'] = (batch_size, in_len, hidden_size)
                shape_dict['output'] = shape_dict['hidden_states']
                past_key_value_length = torch.tensor([step], dtype=torch.int32)

                input = torch.randint(-16,
                                      16,
                                      shape_dict['hidden_states'],
                                      dtype=torch.int8)

                # torch execution
                torch_output, torch_present = gt_attention(
                    input,
                    dtype,
                    layer_past=None,
                    use_cache=True,
                    quant_mode=quant_mode)

                use_past_key_value = False
                use_gpt_attention_plugin = True
                outputs = _construct_execution()
                output = outputs["output"]
                present_key = outputs["present_key"]
                present_value = outputs["present_value"]

                print(output, torch_output)

                # assert_close requires rtol whenever atol is given
                torch.testing.assert_close(output,
                                           torch_output,
                                           atol=1e-2,
                                           rtol=0)

            else:
                # Generation stage
                shape_dict['hidden_states'] = (batch_size, 1, hidden_size)
                shape_dict['output'] = shape_dict['hidden_states']
                past_key_value_length = torch.tensor([in_len + step - 1],
                                                     dtype=torch.int32)
                input = torch.randint(-16,
                                      16,
                                      shape_dict['hidden_states'],
                                      dtype=torch.int8)

                # torch execution
                torch_output, torch_present = gt_attention(
                    input,
                    dtype,
                    layer_past=torch_present,
                    use_cache=True,
                    quant_mode=quant_mode)

                use_past_key_value = True
                use_gpt_attention_plugin = True
                outputs = _construct_execution()
                output = outputs["output"]
                present_key = outputs["present_key"]
                present_value = outputs["present_value"]

                print(output, torch_output)

                torch.testing.assert_close(output,
                                           torch_output,
                                           atol=1e-2,
                                           rtol=0)

    def test_quantize_per_tensor(self):
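        # The Quantize layer is expected to divide by the scaling factor,
        # matching torch.quantize_per_tensor used as the reference below.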
        dtype = 'float32'
        x_data = torch.randn((1, 2, 2, 4), dtype=torch.float32, device="cuda")
        scaling_factor_data = torch.tensor(0.4, dtype=torch.float32)
        builder = tensorrt_llm.Builder()
        network = builder.create_network()

        with tensorrt_llm.net_guard(network):
            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            q_layer = tensorrt_llm.quantization.layers.Quantize('int8')
            q_layer.scaling_factor.value = scaling_factor_data.numpy()
            output = q_layer.forward(x)
            output.mark_output('output', trt.int8)

        session = create_session(builder, network, precision=dtype, int8=True)
        inputs = {
            'x': x_data,
        }

        outputs = run_session(session, inputs)

        ref = torch.quantize_per_tensor(x_data, scaling_factor_data.cuda(), 0,
                                        torch.qint8)

        # Compare integer representations, since `ref` is a quantized tensor
        torch.testing.assert_close(ref.int_repr(), outputs['output'])

    def test_quantize_per_channel(self):
        dtype = 'float32'
        x_data = torch.randn((2, 4, 4, 8), dtype=torch.float32, device="cuda")
        scaling_factor_data = torch.tensor((0.4, 0.1, 0.3, 0.2),
                                           dtype=torch.float32)
        builder = tensorrt_llm.Builder()
        network = builder.create_network()

        with tensorrt_llm.net_guard(network):

            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            axis = 1
            q_layer = tensorrt_llm.quantization.layers.Quantize(
                'int8', 'float32', x_data.shape[axis], axis)
            q_layer.scaling_factor.value = scaling_factor_data.detach().cpu(
            ).numpy()
            output = q_layer.forward(x)
            output.mark_output('output')

        session = create_session(builder, network, precision=dtype, int8=True)
        inputs = {
            'x': x_data,
        }

        outputs = run_session(session, inputs)

        ref = torch.quantize_per_channel(
            x_data, scaling_factor_data.cuda(),
            torch.tensor([0, 0, 0, 0], device="cuda"), 1, torch.qint8)
        # Compare integer representations, since `ref` is a quantized tensor
        torch.testing.assert_close(ref.int_repr(), outputs['output'])

    @parameterized.expand([('float16'), ('bfloat16'), ('float32')],
                          name_func=unittest_name_func)
    def test_quantize_per_token(self, dtype):
        # Skip tests that are not supported in pre-ampere architecture
        skip_bf16_pre_ampere(dtype)

        x_data = torch.randn(
            (2, 4, 4, 8),
            dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype),
            device="cuda")

        builder = tensorrt_llm.Builder()
        network = builder.create_network()

        with tensorrt_llm.net_guard(network):

            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            q_layer = tensorrt_llm.quantization.layers.QuantizePerToken()
            output, scale = q_layer.forward(x)

            output.mark_output('output')
            scale.mark_output('scale')

        session = create_session(builder, network, precision=dtype, int8=True)
        inputs = {
            'x': x_data,
        }

        outputs = run_session(session, inputs)

        ref, ref_scale = _utils.gt_quantize_per_token(x_data)
        ref = ref.reshape(outputs['output'].shape)
        ref_scale = ref_scale.reshape(outputs['scale'].shape)

        torch.testing.assert_close(ref, outputs['output'], atol=1, rtol=1e-1)

        torch.testing.assert_close(ref_scale.float(),
                                   outputs['scale'].float(),
                                   atol=1e-2,
                                   rtol=1e-1)

    def test_dequantize(self):
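        # Dequantize multiplies the int8 input by its scaling factor; the
        # reference torch.dequantize applies the same 0.1 scale baked into
        # x_data.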
        dtype = 'float32'
        x_data = torch.quantize_per_tensor(
            torch.tensor([-1.0, 0.0, 1.0, 2.0],
                         dtype=torch.float32,
                         device="cuda"), 0.1, 0, torch.qint8)
        scaling_factor_data = torch.tensor(0.1, dtype=torch.float32)

        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        with tensorrt_llm.net_guard(network):

            x = Tensor(name='x',
                       shape=x_data.shape,
                       dtype=tensorrt_llm.str_dtype_to_trt('int8'))
            dq_layer = tensorrt_llm.quantization.layers.Dequantize()

            dq_layer.scaling_factor.value = scaling_factor_data.numpy()
            output = dq_layer.forward(x)
            output.mark_output('output', dtype)

        session = create_session(builder, network, precision=dtype, int8=True)
        inputs = {
            'x': x_data,
        }

        outputs = run_session(session, inputs)

        ref = torch.dequantize(x_data)

        torch.testing.assert_close(ref, outputs['output'])


if __name__ == '__main__':
    unittest.main()