TensorRT-LLMs/tests/functional/test_moe.py
Kaiyu Xie 4bb65f216f
Update TensorRT-LLM (#1274)
* Update TensorRT-LLM

---------

Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
2024-03-12 18:15:52 +08:00

659 lines
27 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import unittest
import numpy as np
# isort: off
import torch
import tensorrt as trt
# isort: on
import os
import sys
from parameterized import parameterized
from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner
import tensorrt_llm
from tensorrt_llm import Tensor
from tensorrt_llm._utils import torch_to_numpy, trt_dtype_to_torch
from tensorrt_llm.layers.moe import MoeConfig
from tensorrt_llm.quantization import QuantMode
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion, skip_bf16_pre_ampere, unittest_name_func
# Activation function name used by make_tuple() when a test case does not override `actfn`.
default_actfn = 'gelu'
def make_tuple(num_experts=4,
               topk=1,
               hidden_size=8,
               num_sequences=5,
               sequence_length=4,
               actfn=default_actfn,
               bias=True,
               dtype='float32',
               weight_dtype=None,
               norm_mode=MoeConfig.ExpertScaleNormalizationMode.NONE,
               use_plugin=True):
    """Pack one MOE test configuration into a positional tuple.

    Every argument has a default so callers only spell out what differs from
    the baseline case. When `weight_dtype` is None it falls back to `dtype`.

    Returns:
        The tuple (num_experts, topk, hidden_size, num_sequences,
        sequence_length, actfn, bias, dtype, weight_dtype, norm_mode,
        use_plugin).
    """
    effective_weight_dtype = dtype if weight_dtype is None else weight_dtype
    case = (
        num_experts,
        topk,
        hidden_size,
        num_sequences,
        sequence_length,
        actfn,
        bias,
        dtype,
        effective_weight_dtype,
        norm_mode,
        use_plugin,
    )
    return case
def gen_uniform_weights(*args, **kwargs):
    """Return a contiguous tensor of samples drawn uniformly from [-1, 1).

    All positional and keyword arguments are forwarded unchanged to
    `torch.rand`, whose [0, 1) samples are then rescaled.
    """
    unit_samples = torch.rand(*args, **kwargs)
    rescaled = unit_samples * 2 - 1
    return rescaled.contiguous()
def quant_dequant(weights, quant_mode):
    """Round-trip `weights` through weight-only quantization.

    When `quant_mode` is not weight-only, `weights` is returned untouched.
    Otherwise the transposed weights are quantized on the CPU (int4 or int8,
    as selected by `quant_mode`), multiplied back by the returned scales in
    the original dtype, transposed back and moved to the original device.

    NOTE(review): the int4 unpacking stacks along dim=2, so `weights` is
    presumably a 2-D matrix -- confirm against callers.
    """
    if not quant_mode.is_weight_only():
        return weights

    is_int4 = quant_mode.is_int4_weight_only()
    quant_type = torch.quint4x2 if is_int4 else torch.int8
    transposed = weights.T

    # use the test version `_symmetric_...` to get the non-interleaved weights
    packed, _, scales = torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(
        transposed.cpu().contiguous(), quant_type)

    if is_int4:
        # Unpack the int4s into int8s
        high_nibble = packed >> 4
        low_nibble = (packed << 4) >> 4  # Arithmetic right shift sign extends
        packed = torch.stack((low_nibble, high_nibble),
                             dim=2).view(transposed.shape)

    packed = packed.to(dtype=weights.dtype)
    dequantized = torch.multiply(packed, scales.unsqueeze(0)).T.contiguous()
    return dequantized.to(device=weights.device)
# Maps each gated activation name to the plain activation that gated_matmul()
# applies to the gate half of the FC1 output.
GATED_TO_ACT = {
'swiglu': 'silu',
'geglu': 'gelu',
}
def is_gated_activation(actfn):
    """Return True when `actfn` is one of the gated activation names in GATED_TO_ACT."""
    is_gated = actfn in GATED_TO_ACT
    return is_gated
def gated2act(actfn):
    """Translate a gated activation name to its plain activation.

    Names that are not keys of GATED_TO_ACT are returned unchanged.
    """
    return GATED_TO_ACT.get(actfn, actfn)
def doact(input, actfn):
    """Apply the plain (non-gated) activation named `actfn` to `input`.

    Supports 'gelu', 'relu', 'silu' and 'identity'; gated names and any other
    value trip an assertion.
    """
    assert not is_gated_activation(actfn)
    plain_activations = {
        'gelu': torch.nn.functional.gelu,
        'relu': torch.nn.functional.relu,
        'silu': torch.nn.functional.silu,
    }
    activation = plain_activations.get(actfn)
    if activation is not None:
        return activation(input)
    assert actfn == "identity"
    return input  # Identity
def gated_matmul(input, weights, bias, actfn):
    """Reference FC1 for gated activations.

    Computes `input @ weights.T + bias`, splits the result into two halves
    along the last dimension, and multiplies the first half by the plain
    activation (see gated2act) of the second half.
    """
    assert is_gated_activation(actfn)
    projected = torch.matmul(input, weights.T) + bias
    linear_half, gate_half = projected.chunk(2, dim=-1)
    activated_gate = doact(gate_half, gated2act(actfn))
    return linear_half * activated_gate
# Functional tests for the MOE layer: each test builds and runs a TensorRT
# network through trtImpl() and compares its output against a reference.
class TestFunctional(unittest.TestCase):
def setUp(self):
    """Silence non-error logging and pin the torch RNG seed before each test."""
    tensorrt_llm.logger.set_level('error')
    # There is a known precision issue where the topk may select different
    # experts when the routing probabilities are similar. That produces a
    # completely different output for the affected tokens, so the seed is
    # fixed to prevent sporadic failures. This shouldn't be a problem for most
    # practical applications as it means the experts are equally good choices.
    torch.manual_seed(0x766E)
def eye(self, shape, dtype, device='cuda'):
    """ Utility function for creating expert weights as an identity matrix for easy debugging

    The last two entries of `shape` give the matrix rows and columns; the
    identity matrix is repeated once per combination of the leading entries.
    """
    rows, cols = shape[-2], shape[-1]
    identity = torch.eye(rows, m=cols, dtype=dtype, device=device)
    leading_dims = shape[:-2]
    return identity.repeat(*leading_dims, 1, 1)
@staticmethod
def get_params():
    """Build the list of parameter tuples for `test_mixture_of_experts`.

    Each entry is produced by `make_tuple`, so it only names the settings
    that differ from the baseline case. bfloat16 cases are added only when
    `getSMVersion() >= 80`, and cases with `use_plugin=False` (the "OOTB"
    path) only when `getSMVersion() >= 90`.

    This function is called directly in the class body by the
    `@parameterized.expand` decorator, which relies on staticmethod objects
    being callable (Python 3.10+).

    Returns:
        list of tuples in `make_tuple` order.
    """
    sm_version = getSMVersion()
    bf16_supported = sm_version >= 80
    # TODO: Support ootb path with getSMVersion() < 90
    ootb_supported = sm_version >= 90
    renormalize = MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE

    params = []

    # Some default values to use for most test cases
    # NOTE(review): the strict `topk < experts` means the `experts=1` entry
    # never yields a case and topk == num_experts is never exercised --
    # confirm whether `<=` was intended.
    for experts in [1, 4, 42, 1024]:
        for topk in [1, 2, 3]:
            if topk < experts:
                params.append(
                    make_tuple(num_experts=experts,
                               topk=topk,
                               dtype='float16'))

    # OOTB test
    for experts in [1, 4, 42]:
        for topk in [1, 2, 3]:
            if topk < experts and ootb_supported:
                params.append(
                    make_tuple(num_experts=experts,
                               topk=topk,
                               dtype='float16',
                               use_plugin=False))

    # Vary the total token count and how it is split into sequences
    for num_tokens in [1, 42, 100]:
        for sequence_length in [1, 3, 42]:
            num_sequences = math.ceil(num_tokens / sequence_length)
            params.append(
                make_tuple(num_sequences=num_sequences,
                           sequence_length=sequence_length,
                           dtype='float16'))
            if ootb_supported:
                params.append(
                    make_tuple(num_sequences=num_sequences,
                               sequence_length=sequence_length,
                               dtype='float16',
                               use_plugin=False))

    # Add a test for float32
    # Try 5 because non-power 2 use a different topk kernel
    params.append(make_tuple(num_experts=5, dtype='float32'))
    if ootb_supported:
        params.append(
            make_tuple(num_experts=5, dtype='float32', use_plugin=False))

    # Add a test for bfloat16
    if bf16_supported:
        params += [
            make_tuple(dtype='bfloat16'),
            # Try 5 because non-power 2 use a different topk kernel
            make_tuple(num_experts=5, dtype='bfloat16'),
        ]

    # Add some cases for quantized dtype
    for dtype in ('int8', 'int4'):
        params += [
            make_tuple(dtype='float16', hidden_size=64, weight_dtype=dtype),
            make_tuple(dtype='float16',
                       hidden_size=64,
                       num_experts=5,
                       weight_dtype=dtype),
        ]
        if bf16_supported:
            params.append(
                make_tuple(dtype='bfloat16',
                           hidden_size=64,
                           weight_dtype=dtype))

    # Test all activation functions with float16
    for actfn in ('relu', 'silu', 'gelu', 'swiglu', 'geglu', 'identity'):
        if actfn == default_actfn:
            continue  # Don't need to retest the one every other case uses
        params.append(make_tuple(actfn=actfn, dtype='float16'))
        # Test OOTB path with all activation functions with float16
        if ootb_supported:
            params.append(
                make_tuple(actfn=actfn, dtype='float16', use_plugin=False))

    # Test gated with all data types as it has a different path
    for actfn in ('swiglu', 'geglu'):
        if actfn == default_actfn:
            continue  # Don't need to retest the one every other case uses
        params += [
            make_tuple(actfn=actfn, dtype='float32'),
            make_tuple(actfn=actfn,
                       hidden_size=64,
                       dtype='float16',
                       weight_dtype='int8'),
        ]
        if bf16_supported:
            params.append(make_tuple(actfn=actfn, dtype='bfloat16'))

    # Test different k values for gated activations
    params.append(make_tuple(actfn='geglu', topk=2, dtype='float16'))
    if ootb_supported:
        # This case sits behind the OOTB guard, so it must actually select the
        # OOTB path; without use_plugin=False it only re-ran the plugin path.
        params.append(
            make_tuple(actfn='geglu',
                       topk=2,
                       bias=False,
                       dtype='float16',
                       use_plugin=False))

    # Test no bias
    params += [
        make_tuple(bias=False, dtype='float32'),
        make_tuple(bias=False, dtype='float16'),
        make_tuple(dtype='float16',
                   hidden_size=64,
                   weight_dtype='int8',
                   bias=False),
        make_tuple(dtype='float16',
                   hidden_size=64,
                   weight_dtype='int4',
                   bias=False),
    ]

    # Test renormalization
    params += [
        make_tuple(topk=2, dtype='float32', norm_mode=renormalize),
        make_tuple(topk=2, dtype='float16', norm_mode=renormalize),
        make_tuple(dtype='float16',
                   topk=2,
                   hidden_size=64,
                   weight_dtype='int8',
                   norm_mode=renormalize),
        make_tuple(dtype='float16',
                   topk=2,
                   hidden_size=128,
                   weight_dtype='int4',
                   norm_mode=renormalize),
        # Renorm affects the final accumulate, so sanity check with no bias too
        make_tuple(norm_mode=renormalize,
                   topk=2,
                   dtype='float16',
                   bias=False),
    ]

    # Test OOTB renormalization
    if ootb_supported:
        params += [
            make_tuple(topk=2,
                       dtype='float32',
                       norm_mode=renormalize,
                       use_plugin=False),
            make_tuple(topk=2,
                       dtype='float16',
                       norm_mode=renormalize,
                       use_plugin=False),
            # Renorm affects the final accumulate, so sanity check with no bias too
            make_tuple(norm_mode=renormalize,
                       topk=2,
                       dtype='float16',
                       bias=False,
                       use_plugin=False),
        ]

    if bf16_supported:
        params.append(
            make_tuple(dtype='bfloat16', topk=2, norm_mode=renormalize))
        if ootb_supported:
            params.append(
                make_tuple(dtype='bfloat16',
                           topk=2,
                           norm_mode=renormalize,
                           use_plugin=False))

    return params
def create_weights(self, num_experts, hidden_size, ffn_hidden_size, bias,
                   dtype, weight_dtype, is_gated):
    """Create random router/expert weights and biases on the GPU.

    Populates `self.router_weights`, `self.fc1_weights`, `self.fc2_weights`,
    `self.fc1_bias` and `self.fc2_bias`, all in the torch dtype matching the
    TensorRT `dtype`. FC1's output dimension is doubled when `is_gated` is
    true. When `bias` is false the bias tensors are all zeros.
    """
    tensor_kwargs = dict(dtype=trt_dtype_to_torch(dtype), device="cuda")

    # Use a uniform scale for int8 so the quantization has a well-behaved dynamic range
    if weight_dtype == trt.int8:
        genfn = gen_uniform_weights
    else:
        genfn = torch.randn
    bias_tensor_func = genfn if bias else torch.zeros

    if is_gated:
        fc1_out_size = ffn_hidden_size * 2
    else:
        fc1_out_size = ffn_hidden_size

    # The tensors are drawn in a fixed order so that a given seed always
    # produces the same weights.
    self.router_weights = torch.randn((num_experts, hidden_size),
                                      **tensor_kwargs)
    self.fc1_weights = genfn((num_experts, fc1_out_size, hidden_size),
                             **tensor_kwargs)
    self.fc2_weights = genfn((num_experts, hidden_size, ffn_hidden_size),
                             **tensor_kwargs)
    self.fc1_bias = bias_tensor_func((num_experts, fc1_out_size),
                                     **tensor_kwargs)
    self.fc2_bias = bias_tensor_func((num_experts, hidden_size),
                                     **tensor_kwargs)
@parameterized.expand(get_params(), name_func=unittest_name_func)
def test_mixture_of_experts(self, num_experts, top_k, hidden_size,
                            num_sequences, sequence_lengths, actfn, bias,
                            dtype_str, weight_dtype_str, norm_mode,
                            use_plugin):
    """ This test compares the MOE result to a simple reference implementation using torch

    The arguments are one tuple from `get_params()`. The tolerance is chosen
    by `weight_dtype_str`.
    """
    use_int4_weights = weight_dtype_str == 'int4'
    dtype = tensorrt_llm.str_dtype_to_trt(dtype_str)
    # int4 weights are represented with the int8 TensorRT dtype here
    if use_int4_weights:
        weight_dtype = trt.int8
    else:
        weight_dtype = tensorrt_llm.str_dtype_to_trt(weight_dtype_str)

    if weight_dtype != dtype:
        quant_mode = QuantMode.use_weight_only(
            use_int4_weights=use_int4_weights)
    else:
        quant_mode = QuantMode(0)

    ffn_hidden_size = 4 * hidden_size
    self.create_weights(num_experts,
                        hidden_size,
                        ffn_hidden_size,
                        bias,
                        dtype,
                        weight_dtype,
                        is_gated=is_gated_activation(actfn))
    input_shape = (num_sequences, sequence_lengths, hidden_size)
    input_data = gen_uniform_weights(input_shape,
                                     dtype=trt_dtype_to_torch(dtype))

    # construct trt network
    trt_outputs = self.trtImpl(input_data,
                               num_experts,
                               top_k,
                               hidden_size,
                               ffn_hidden_size,
                               actfn,
                               bias,
                               dtype,
                               weight_dtype=weight_dtype,
                               quant_mode=quant_mode,
                               norm_mode=norm_mode,
                               use_plugin=use_plugin)
    trt_res = trt_outputs['output'].float()

    ref = self.referenceImpl(input_data, top_k, actfn, weight_dtype,
                             quant_mode, norm_mode)
    ref = ref.cpu().float()

    tolerances = {
        'float32': 1e-2,
        'float16': 5e-2,
        'bfloat16': 5e-2,
        'int8': 2e-1,
        'int4': 2e-1,
    }
    tolerance = tolerances[weight_dtype_str]
    # NOTE: There is a known issue where similar routing values result in selecting a different expert to the reference
    # This shouldn't cause issues in production, but will cause large deviations in the test results
    np.testing.assert_allclose(trt_res, ref, rtol=tolerance, atol=tolerance)
@staticmethod
def get_mlp_params():
    """Build the (dtype_str, actfn, use_plugin) cases for `test_mlp_comparison`.

    For each of 'gelu' and 'geglu', every dtype is listed with the plugin
    enabled; the non-quantized dtypes are repeated with the plugin disabled
    only when `getSMVersion() >= 90`.
    """
    plugin_dtypes = ('float32', 'float16', 'bfloat16', 'int8', 'int4')
    # OOTB tests
    # TODO: Support ootb path with getSMVersion() < 90, quantization:
    ootb_dtypes = ('float32', 'float16', 'bfloat16')
    ootb_supported = getSMVersion() >= 90

    params = []
    for actfn in ('gelu', 'geglu'):
        params.extend((dtype_str, actfn, True) for dtype_str in plugin_dtypes)
        if ootb_supported:
            params.extend(
                (dtype_str, actfn, False) for dtype_str in ootb_dtypes)
    return params
@parameterized.expand(get_mlp_params(), name_func=unittest_name_func)
def test_mlp_comparison(self, dtype_str, actfn, use_plugin):
    """ This test uses one expert and compares the result to a plain MLP

    The MLP (or GatedMLP for gated activations) is added to the same network
    as the MOE layer and loaded with the single expert's weights, so both
    outputs come out of one engine run.
    """
    skip_bf16_pre_ampere(dtype_str)

    use_int4_weights = dtype_str == 'int4'
    if use_int4_weights:
        weight_dtype = trt.int8
    else:
        weight_dtype = tensorrt_llm.str_dtype_to_trt(dtype_str)

    if dtype_str in ('int8', 'int4'):
        # Weight-only quantization: float16 activations and a larger hidden size
        dtype = tensorrt_llm.str_dtype_to_trt("float16")
        hidden_size = 64
        quant_mode = QuantMode.use_weight_only(
            use_int4_weights=use_int4_weights)
    else:
        dtype = weight_dtype
        hidden_size = 8
        quant_mode = QuantMode(0)

    num_sequences, sequence_lengths = 5, 4
    num_experts, top_k = 1, 1
    bias = True
    ffn_hidden_size = 4 * hidden_size
    gated = is_gated_activation(actfn)

    self.create_weights(num_experts,
                        hidden_size,
                        ffn_hidden_size,
                        bias,
                        dtype,
                        weight_dtype,
                        is_gated=gated)
    input_data = gen_uniform_weights(
        (num_sequences, sequence_lengths, hidden_size),
        dtype=trt_dtype_to_torch(dtype))

    def as_numpy(tensor):
        # Contiguous numpy copy suitable for assigning to a layer parameter
        return np.ascontiguousarray(torch_to_numpy(tensor))

    def build_reference_mlp(network, trt_key, _):
        if gated:
            mlp_type = tensorrt_llm.layers.GatedMLP
        else:
            mlp_type = tensorrt_llm.layers.MLP
        mlp = mlp_type(hidden_size=hidden_size,
                       ffn_hidden_size=ffn_hidden_size,
                       hidden_act=gated2act(actfn),
                       bias=bias,
                       quant_mode=quant_mode,
                       dtype=dtype)

        # Quantize the weights manually so the results are comparable
        fc1_qd = quant_dequant(self.fc1_weights[0].cpu(), quant_mode)
        if gated:
            # Note that the MLP uses the opposite convention to the GLU paper for naming,
            # the gate is the matrix the activations are NOT applied to
            gate_weight, fc1_qd = fc1_qd.chunk(2, dim=0)
            mlp.gate.weight.value = as_numpy(gate_weight)
        mlp.fc.weight.value = as_numpy(fc1_qd)

        fc2_qd = quant_dequant(self.fc2_weights[0].cpu(), quant_mode)
        mlp.proj.weight.value = as_numpy(fc2_qd)

        if bias:
            fc1_bias = self.fc1_bias[0].cpu()
            if gated:
                gate_bias, fc1_bias = fc1_bias.chunk(2, dim=0)
                mlp.gate.bias.value = as_numpy(gate_bias)
            mlp.fc.bias.value = as_numpy(fc1_bias)
            mlp.proj.bias.value = as_numpy(self.fc2_bias[0].cpu())

        mlp_output = mlp(trt_key).trt_tensor
        mlp_output.name = 'mlp_output'
        network.mark_output(mlp_output)
        mlp_output.dtype = dtype

    res = self.trtImpl(input_data,
                       num_experts,
                       top_k,
                       hidden_size,
                       ffn_hidden_size,
                       actfn,
                       bias,
                       dtype,
                       weight_dtype=weight_dtype,
                       quant_mode=quant_mode,
                       custom_network=build_reference_mlp,
                       use_plugin=use_plugin)

    # Some issues for geglu on volta
    float16_tolerance = 1e-2 if getSMVersion() >= 75 else 1e-1
    tolerances = {
        'float32': 1e-2,
        'float16': float16_tolerance,
        'bfloat16': 1e-1,
        'int8': 2e-1,
        'int4': 2e-1,
    }
    tolerance = tolerances[dtype_str]
    np.testing.assert_allclose(res['output'].float(),
                               res['mlp_output'].float(),
                               rtol=tolerance,
                               atol=tolerance)
def set_weight_layer(self, input_weights, weight, scale, quant_mode):
    """Load `input_weights` into a layer's `weight` (and `scale`) parameters.

    Without weight-only quantization the weights are assigned as-is and
    `scale` is left untouched. With it, dims 1 and 2 of the weights are
    swapped, the result is quantized on the CPU (int4 or int8, as selected by
    `quant_mode`), and both the processed weights and their scales are
    assigned.
    """
    if not quant_mode.is_weight_only():
        weight.value = np.ascontiguousarray(torch_to_numpy(input_weights))
        return

    transposed = torch.transpose(input_weights, 1, 2).contiguous().cpu()
    if quant_mode.is_int4_weight_only():
        quant_type = torch.quint4x2
    else:
        quant_type = torch.int8
    processed_weights, weight_scales = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
        transposed, quant_type)

    weight.value = np.ascontiguousarray(torch_to_numpy(processed_weights))
    scale.value = np.ascontiguousarray(torch_to_numpy(weight_scales))
def trtImpl(self,
            input_data,
            num_experts,
            top_k,
            hidden_size,
            ffn_hidden_size,
            actfn,
            bias,
            dtype: trt.DataType,
            weight_dtype: trt.DataType = None,
            quant_mode=QuantMode(0),
            norm_mode=MoeConfig.ExpertScaleNormalizationMode.NONE,
            finished=None,
            custom_network=None,
            use_plugin=True):
    """Build a TensorRT network containing one MOE layer and run it.

    The layer is loaded with the weights previously stored on `self` by
    `create_weights`, and its result is marked as the network output named
    'output'.

    Args:
        input_data: tensor fed as 'input_hidden_states'.
        finished: optional tensor fed as 'input_finished'; when None no such
            input is created.
        custom_network: optional callable invoked as
            `custom_network(network, trt_key, trt_finished)` before the MOE
            layer is applied, so it can add further outputs.
        use_plugin: when true, the MOE plugin is enabled for `dtype`.

    Returns:
        Whatever `TrtRunner.infer` returns for the built engine; callers
        index it by output name.
    """
    builder = tensorrt_llm.Builder()
    net = builder.create_network()
    if use_plugin:
        net.plugin_config.set_moe_plugin(dtype)

    with tensorrt_llm.net_guard(net):
        network = tensorrt_llm.default_trtnet()
        trt_key = Tensor(name='input_hidden_states',
                         shape=tuple(input_data.shape),
                         dtype=dtype)
        if finished is None:
            trt_finished = None
        else:
            trt_finished = Tensor(
                name='input_finished',
                shape=tuple(finished.shape),
                dtype=tensorrt_llm.str_dtype_to_trt('bool'))

        moe_config = MoeConfig(num_experts=num_experts,
                               top_k=top_k,
                               normalization_mode=norm_mode)
        moe = tensorrt_llm.layers.MOE(moe_config=moe_config,
                                      hidden_size=hidden_size,
                                      ffn_hidden_size=ffn_hidden_size,
                                      hidden_act=actfn,
                                      bias=bias,
                                      dtype=dtype,
                                      quant_mode=quant_mode)

        moe.router.weight.value = torch_to_numpy(self.router_weights.cpu())
        self.set_weight_layer(self.fc1_weights, moe.experts_weight_1,
                              moe.experts_scale_1, quant_mode)
        self.set_weight_layer(self.fc2_weights, moe.experts_weight_2,
                              moe.experts_scale_2, quant_mode)
        if bias:
            moe.experts_bias_1.value = torch_to_numpy(self.fc1_bias.cpu())
            moe.experts_bias_2.value = torch_to_numpy(self.fc2_bias.cpu())

        if custom_network:
            custom_network(network, trt_key, trt_finished)

        moe_output = moe(trt_key, trt_finished).trt_tensor
        moe_output.name = 'output'
        network.mark_output(moe_output)
        moe_output.dtype = dtype

    # trt run
    engine_config = CreateConfig(fp16=(dtype == trt.float16),
                                 bf16=(dtype == trt.bfloat16),
                                 int8=(weight_dtype == trt.int8),
                                 precision_constraints='obey',
                                 builder_optimization_level=4)
    build_engine = EngineFromNetwork((builder.trt_builder, net.trt_network),
                                     config=engine_config)
    assert build_engine is not None

    feed_dict = {'input_hidden_states': input_data}
    if finished is not None:
        feed_dict['input_finished'] = finished
    with TrtRunner(build_engine) as runner:
        outputs = runner.infer(feed_dict=feed_dict)
    return outputs
def referenceImpl(self, inputs, k, actfn, weight_dtype, quant_mode,
                  norm_mode):
    """Token-by-token torch reference for the MOE layer.

    Routes every token through its top-`k` experts (softmax over the router
    logits), runs each selected expert's FC1 -> activation -> FC2 with
    quantize/dequantized weights, and accumulates the expert outputs weighted
    by their routing scales. With RENORMALIZE the selected scales are first
    divided by their sum. `weight_dtype` is accepted but not used.

    Returns:
        A float32 CUDA tensor with the same shape as `inputs`.
    """
    # Always run the ref implementation at full precision TODO is this a good choice?
    inputs = inputs.cuda().float()
    tokens = inputs.view(-1, inputs.shape[-1])

    router_logits = torch.matmul(tokens, self.router_weights.T.float())
    assert router_logits.shape == (tokens.shape[0],
                                   self.router_weights.shape[0])
    router_probs = torch.softmax(router_logits, 1, dtype=inputs.dtype)
    assert router_logits.shape == router_probs.shape
    selected = torch.topk(router_probs, k)
    assert selected.indices.shape == (router_probs.shape[0], k)

    renormalize = norm_mode == MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE
    gated = is_gated_activation(actfn)
    results = torch.zeros_like(tokens)
    for row, (scales, experts) in enumerate(
            zip(selected.values, selected.indices)):
        if renormalize:
            scales /= sum(scales)
        token = tokens[row, :]
        for scale, expert in zip(scales, experts):
            fc1_qd = quant_dequant(self.fc1_weights[expert], quant_mode)
            fc1_bias = self.fc1_bias[expert].float()
            if gated:
                hidden = gated_matmul(token, fc1_qd.float(), fc1_bias, actfn)
            else:
                hidden = torch.matmul(token, fc1_qd.T.float()) + fc1_bias
                hidden = doact(hidden, actfn)

            fc2_qd = quant_dequant(self.fc2_weights[expert], quant_mode)
            final = torch.matmul(
                hidden, fc2_qd.T.float()) + self.fc2_bias[expert].float()
            assert final.shape == (inputs.shape[-1], )
            results[row] += scale * final
    return results.view(*inputs.shape)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()