# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import sys
import unittest

import numpy as np
import pytest
import tensorrt as trt
import torch
from parameterized import parameterized
from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner

import tensorrt_llm
from tensorrt_llm import Tensor
from tensorrt_llm._utils import torch_to_numpy, trt_dtype_to_torch
from tensorrt_llm.layers.moe import MOEExpertScaleNormalizationMode
from tensorrt_llm.quantization import QuantMode

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion

default_actfn = 'gelu'


def make_tuple(num_experts=4,
               topk=1,
               hidden_size=8,
               num_sequences=5,
               sequence_length=4,
               actfn=default_actfn,
               bias=True,
               dtype='float32',
               weight_dtype=None,
               norm_mode=MOEExpertScaleNormalizationMode.NONE):
    if weight_dtype is None:
        weight_dtype = dtype
    return (num_experts, topk, hidden_size, num_sequences, sequence_length,
            actfn, bias, dtype, weight_dtype, norm_mode)


def gen_uniform_weights(*args, **kwargs):
    return (torch.rand(*args, **kwargs) * 2 - 1).contiguous()


def quant_dequant(weights, weight_dtype):
    if weight_dtype != trt.int8:
        return weights
    # use the test version `_symmetric_...` to get the non-interleaved weights
    quant_weights, _, torch_weight_scales = torch.ops.fastertransformer._symmetric_quantize_last_axis_of_batched_matrix(
        weights.T.cpu().contiguous(), trt_dtype_to_torch(weight_dtype))
    quant_weights = quant_weights.to(dtype=weights.dtype)
    result = torch.multiply(quant_weights,
                            torch_weight_scales.unsqueeze(0)).T.contiguous()
    return result.to(device=weights.device)


GATED_TO_ACT = {
    'swiglu': 'silu',
    'geglu': 'gelu',
}


def is_gated_activation(actfn):
    return actfn in GATED_TO_ACT


def gated2act(actfn):
    if is_gated_activation(actfn):
        return GATED_TO_ACT[actfn]
    return actfn


def doact(input, actfn):
    assert not is_gated_activation(actfn)
    if actfn == 'gelu':
        return torch.nn.functional.gelu(input)
    if actfn == 'relu':
        return torch.nn.functional.relu(input)
    if actfn == 'silu':
        return torch.nn.functional.silu(input)
    assert actfn == "identity"
    return input  # Identity


def gated_matmul(input, weights, bias, actfn):
    assert is_gated_activation(actfn)
    fc1 = torch.matmul(input, weights.T) + bias
    fc1, gate = fc1.chunk(2, dim=-1)
    return fc1 * doact(gate, gated2act(actfn))
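
# A minimal sketch (assuming the conventions of gated_matmul and quant_dequant
# above, with illustrative names x, W1/b1, W2/b2) of what the helpers compute
# for a single expert with a 'geglu' FFN:
#
#   # one fused FC1 matmul is split into a linear half and a gated half
#   fc1, gate = (x @ W1.T + b1).chunk(2, dim=-1)
#   y = (fc1 * torch.nn.functional.gelu(gate)) @ W2.T + b2
#
#   # int8 weight-only reference path: quantize then dequantize so the torch
#   # reference sees the same rounding error as the plugin
#   W1_ref = quant_dequant(W1, trt.int8)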


class TestFunctional(unittest.TestCase):

    def setUp(self):
        # There is a known precision issue where topk may select different
        # experts when the routing probabilities are similar. This causes a
        # completely different output for the affected tokens, so we set the
        # seed to prevent sporadic failures. This shouldn't be a problem for
        # most practical applications as it means the experts are equally good
        # choices.
        torch.manual_seed(0x766E)

    def eye(self, shape, dtype, device='cuda'):
        """ Utility function for creating expert weights as an identity matrix for easy debugging """
        eye = torch.eye(shape[-2], m=shape[-1], dtype=dtype, device=device)
        eye = eye.repeat(*shape[:-2], 1, 1)
        return eye

    @staticmethod
    def get_params():
        params = []
        # Some default values to use for most test cases
        for experts in [1, 4, 42, 1024]:
            for topk in [1, 2, 3]:
                if topk < experts:
                    params += [
                        make_tuple(num_experts=experts,
                                   topk=topk,
                                   dtype='float16')
                    ]

        for num_tokens in [1, 42, 100]:
            for sequence_length in [1, 3, 42]:
                num_sequences = math.ceil(num_tokens / sequence_length)
                params += [
                    make_tuple(num_sequences=num_sequences,
                               sequence_length=sequence_length,
                               dtype='float16')
                ]

        # Add a test for float32
        params += [
            make_tuple(dtype='float32'),
            # Try 5 because non-power of 2 uses a different topk kernel
            make_tuple(num_experts=5, dtype='float32')
        ]

        # Add a test for bfloat16
        if getSMVersion() >= 80:
            params += [
                make_tuple(dtype='bfloat16'),
                # Try 5 because non-power of 2 uses a different topk kernel
                make_tuple(num_experts=5, dtype='bfloat16')
            ]

        # Add some cases for quantized dtype
        params += [
            make_tuple(dtype='float16', hidden_size=64, weight_dtype='int8'),
            make_tuple(dtype='float16',
                       hidden_size=64,
                       num_experts=5,
                       weight_dtype='int8'),
        ]
        if getSMVersion() >= 80:
            params += [
                make_tuple(dtype='bfloat16',
                           hidden_size=64,
                           weight_dtype='int8')
            ]

        # Test all activation functions with float16
        for actfn in ('relu', 'silu', 'gelu', 'swiglu', 'geglu', 'identity'):
            if actfn == default_actfn:
                continue  # Don't need to retest the one every other case uses
            params += [make_tuple(actfn=actfn, dtype='float16')]

        # Test gated with all data types as it has a different path
        for actfn in ('swiglu', 'geglu'):
            if actfn == default_actfn:
                continue  # Don't need to retest the one every other case uses
            params += [
                make_tuple(actfn=actfn, dtype='float32'),
                make_tuple(actfn=actfn,
                           hidden_size=64,
                           dtype='float16',
                           weight_dtype='int8'),
            ]
            if getSMVersion() >= 80:
                params += [make_tuple(actfn=actfn, dtype='bfloat16')]

        # Test no bias
        params += [
            make_tuple(bias=False, dtype='float32'),
            make_tuple(bias=False, dtype='float16'),
            make_tuple(dtype='float16',
                       hidden_size=64,
                       weight_dtype='int8',
                       bias=False)
        ]

        # Test renormalization
        params += [
            make_tuple(topk=2,
                       dtype='float32',
                       norm_mode=MOEExpertScaleNormalizationMode.RENORMALIZE),
            make_tuple(topk=2,
                       dtype='float16',
                       norm_mode=MOEExpertScaleNormalizationMode.RENORMALIZE),
            make_tuple(dtype='float16',
                       topk=2,
                       hidden_size=64,
                       weight_dtype='int8',
                       norm_mode=MOEExpertScaleNormalizationMode.RENORMALIZE),
            # Renorm affects the final accumulate, so sanity check with no bias too
            make_tuple(norm_mode=MOEExpertScaleNormalizationMode.RENORMALIZE,
                       topk=2,
                       dtype='float16',
                       bias=False),
        ]
        if getSMVersion() >= 80:
            params += [
                make_tuple(
                    dtype='bfloat16',
                    topk=2,
                    norm_mode=MOEExpertScaleNormalizationMode.RENORMALIZE)
            ]

        return params
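
    # Each entry in params is a positional tuple consumed by
    # test_mixture_of_experts. For example, the make_tuple() defaults expand to
    #   (4, 1, 8, 5, 4, 'gelu', True, 'float32', 'float32',
    #    MOEExpertScaleNormalizationMode.NONE)
    # i.e. num_experts, topk, hidden_size, num_sequences, sequence_length,
    # actfn, bias, dtype, weight_dtype, norm_mode.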

    def custom_name_func(testcase_func, param_num, param):
        return "%s_%s" % (
            testcase_func.__name__,
            parameterized.to_safe_name("_".join(str(x) for x in param.args)),
        )

    def create_weights(self, num_experts, hidden_size, ffn_hidden_size, bias,
                       dtype, weight_dtype, is_gated):
        self.router_weights = torch.randn((num_experts, hidden_size),
                                          dtype=trt_dtype_to_torch(dtype),
                                          device="cuda")

        # Use a uniform scale for int8 so the quantization has a well-behaved dynamic range
        genfn = gen_uniform_weights if weight_dtype == trt.int8 else torch.randn

        fc1_out_size = ffn_hidden_size * 2 if is_gated else ffn_hidden_size

        self.fc1_weights = genfn((num_experts, fc1_out_size, hidden_size),
                                 dtype=trt_dtype_to_torch(dtype),
                                 device="cuda")
        self.fc2_weights = genfn((num_experts, hidden_size, ffn_hidden_size),
                                 dtype=trt_dtype_to_torch(dtype),
                                 device="cuda")

        bias_tensor_func = genfn if bias else torch.zeros
        self.fc1_bias = bias_tensor_func((num_experts, fc1_out_size),
                                         dtype=trt_dtype_to_torch(dtype),
                                         device="cuda")
        self.fc2_bias = bias_tensor_func((num_experts, hidden_size),
                                         dtype=trt_dtype_to_torch(dtype),
                                         device="cuda")

    @parameterized.expand(get_params(), name_func=custom_name_func)
    def test_mixture_of_experts(self, num_experts, top_k, hidden_size,
                                num_sequences, sequence_lengths, actfn, bias,
                                dtype_str, weight_dtype_str, norm_mode):
        """ This test compares the MOE plugin result to a simple reference implementation using torch """
        dtype = tensorrt_llm.str_dtype_to_trt(dtype_str)
        weight_dtype = tensorrt_llm.str_dtype_to_trt(weight_dtype_str)

        ffn_hidden_size = 4 * hidden_size
        self.create_weights(num_experts,
                            hidden_size,
                            ffn_hidden_size,
                            bias,
                            dtype,
                            weight_dtype,
                            is_gated=is_gated_activation(actfn))

        input_data = gen_uniform_weights(
            (num_sequences, sequence_lengths, hidden_size),
            dtype=trt_dtype_to_torch(dtype))

        # construct trt network
        trt_res = self.trtImpl(input_data,
                               num_experts,
                               top_k,
                               hidden_size,
                               ffn_hidden_size,
                               actfn,
                               bias,
                               dtype,
                               weight_dtype=weight_dtype,
                               norm_mode=norm_mode)['output'].float()

        ref = self.referenceImpl(input_data, top_k, actfn, weight_dtype,
                                 norm_mode).cpu().float()

        tolerances = {
            'float32': 1e-2,
            'float16': 5e-2,
            'bfloat16': 5e-2,
            'int8': 2e-1,
        }
        # NOTE: There is a known issue where similar routing values result in
        # selecting a different expert from the reference. This shouldn't cause
        # issues in production, but will cause large deviations in the test results
        np.testing.assert_allclose(trt_res,
                                   ref,
                                   rtol=tolerances[weight_dtype_str],
                                   atol=tolerances[weight_dtype_str])
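
    # Note on MOEExpertScaleNormalizationMode.RENORMALIZE: the reference
    # implementation (see referenceImpl below) rescales the top-k routing
    # probabilities so they sum to 1 before the weighted accumulate, i.e.
    #   scale_i -> scale_i / sum_j(scale_j)  over the k selected experts.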

    @staticmethod
    def get_mlp_params():
        params = []
        for actfn in ('gelu', 'geglu'):
            params += [('float32', actfn), ('float16', actfn),
                       ('bfloat16', actfn), ('int8', actfn)]
        return params

    @parameterized.expand(get_mlp_params(), name_func=custom_name_func)
    def test_mlp_comparison(self, dtype_str, actfn):
        """ This test uses one expert and compares the result to a plain MLP """
        if getSMVersion() < 80 and dtype_str == 'bfloat16':
            pytest.skip("Skip bf16 tests on arch < sm80")

        dtype = tensorrt_llm.str_dtype_to_trt(dtype_str)
        weight_dtype = dtype
        hidden_size = 8
        if dtype_str == 'int8':
            dtype = tensorrt_llm.str_dtype_to_trt("float16")
            hidden_size = 64
        num_sequences = 5
        sequence_lengths = 4
        num_experts = 1
        top_k = 1
        bias = True
        ffn_hidden_size = 4 * hidden_size
        self.create_weights(num_experts,
                            hidden_size,
                            ffn_hidden_size,
                            bias,
                            dtype,
                            weight_dtype,
                            is_gated=is_gated_activation(actfn))

        input_data = gen_uniform_weights(
            (num_sequences, sequence_lengths, hidden_size),
            dtype=trt_dtype_to_torch(dtype))

        def MLP(network, trt_key, _):
            mlp_type = tensorrt_llm.layers.GatedMLP if is_gated_activation(
                actfn) else tensorrt_llm.layers.MLP
            mlp = mlp_type(hidden_size=hidden_size,
                           ffn_hidden_size=ffn_hidden_size,
                           hidden_act=gated2act(actfn),
                           bias=bias,
                           dtype=dtype)
            # Quantize the weights manually so the results are comparable
            fc1_qd = quant_dequant(self.fc1_weights[0].cpu(), weight_dtype)
            if is_gated_activation(actfn):
                # Note that the MLP uses the opposite convention to the GLU
                # paper for naming: the gate is the matrix the activations are
                # NOT applied to
                gate, fc1_qd = fc1_qd.chunk(2, dim=0)
                mlp.gate.weight.value = np.ascontiguousarray(
                    torch_to_numpy(gate))
            mlp.fc.weight.value = np.ascontiguousarray(torch_to_numpy(fc1_qd))

            fc2_qd = quant_dequant(self.fc2_weights[0].cpu(), weight_dtype)
            mlp.proj.weight.value = np.ascontiguousarray(torch_to_numpy(fc2_qd))

            if bias:
                fc1_bias = self.fc1_bias[0].cpu()
                if is_gated_activation(actfn):
                    gate, fc1_bias = fc1_bias.chunk(2, dim=0)
                    mlp.gate.bias.value = np.ascontiguousarray(
                        torch_to_numpy(gate))
                mlp.fc.bias.value = np.ascontiguousarray(
                    torch_to_numpy(fc1_bias))
                mlp.proj.bias.value = np.ascontiguousarray(
                    torch_to_numpy(self.fc2_bias[0].cpu()))

            output = mlp(trt_key).trt_tensor
            output.name = 'mlp_output'
            network.mark_output(output)
            output.dtype = dtype

        res = self.trtImpl(input_data,
                           num_experts,
                           top_k,
                           hidden_size,
                           ffn_hidden_size,
                           actfn,
                           bias,
                           dtype,
                           weight_dtype=weight_dtype,
                           custom_network=MLP)

        tolerances = {
            'float32': 1e-2,
            # Some issues for geglu on volta
            'float16': 1e-2 if getSMVersion() >= 75 else 1e-1,
            'bfloat16': 1e-1,
            'int8': 2e-1,
        }
        np.testing.assert_allclose(res['output'].float(),
                                   res['mlp_output'].float(),
                                   rtol=tolerances[dtype_str],
                                   atol=tolerances[dtype_str])

    def set_weight_layer(self, input_weights, weight, scale,
                         weight_dtype: trt.DataType):
        torch_transpose = torch.transpose(input_weights, 1,
                                          2).contiguous().cpu()
        if weight_dtype == trt.int8:
            processed_torch_weights, torch_weight_scales = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
                torch_transpose, trt_dtype_to_torch(weight_dtype))
            weight.value = np.ascontiguousarray(
                torch_to_numpy(processed_torch_weights))
            scale.value = np.ascontiguousarray(
                torch_to_numpy(torch_weight_scales))
        else:
            weight.value = np.ascontiguousarray(torch_to_numpy(torch_transpose))
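
    # trtImpl builds a one-layer network around tensorrt_llm.layers.MOE, loads
    # the weights created by create_weights (int8 weights are pre-processed by
    # symmetric_quantize_last_axis_of_batched_matrix in set_weight_layer), then
    # builds an engine with polygraphy and returns the raw outputs so the tests
    # above can compare them against the torch reference.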

    def trtImpl(self,
                input_data,
                num_experts,
                top_k,
                hidden_size,
                ffn_hidden_size,
                actfn,
                bias,
                dtype: trt.DataType,
                weight_dtype: trt.DataType = None,
                norm_mode=MOEExpertScaleNormalizationMode.NONE,
                finished=None,
                custom_network=None):
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        with tensorrt_llm.net_guard(net):
            network = tensorrt_llm.default_trtnet()
            trt_key = Tensor(name='input_hidden_states',
                             shape=tuple(input_data.shape),
                             dtype=dtype)
            trt_finished = Tensor(
                name='input_finished',
                shape=tuple(finished.shape),
                dtype=tensorrt_llm.str_dtype_to_trt(
                    'bool')) if finished is not None else None

            quant_mode = QuantMode(0)
            if weight_dtype == trt.int8:
                quant_mode = QuantMode.use_weight_only(use_int4_weights=False)

            moe = tensorrt_llm.layers.MOE(num_experts=num_experts,
                                          hidden_size=hidden_size,
                                          ffn_hidden_size=ffn_hidden_size,
                                          hidden_act=actfn,
                                          top_k=top_k,
                                          bias=bias,
                                          dtype=dtype,
                                          quant_mode=quant_mode,
                                          normalization_mode=norm_mode)
            moe.router.weight.value = torch_to_numpy(self.router_weights.cpu())
            self.set_weight_layer(self.fc1_weights, moe.experts_weight_1,
                                  moe.experts_scale_1, weight_dtype)
            self.set_weight_layer(self.fc2_weights, moe.experts_weight_2,
                                  moe.experts_scale_2, weight_dtype)
            if bias:
                moe.experts_bias_1.value = torch_to_numpy(self.fc1_bias.cpu())
                moe.experts_bias_2.value = torch_to_numpy(self.fc2_bias.cpu())

            if custom_network:
                custom_network(network, trt_key, trt_finished)

            output = moe(trt_key, trt_finished).trt_tensor
            output.name = 'output'
            network.mark_output(output)
            output.dtype = dtype

        # trt run
        build_engine = EngineFromNetwork(
            (builder.trt_builder, net.trt_network),
            config=CreateConfig(fp16=(dtype == trt.float16),
                                bf16=(dtype == trt.bfloat16),
                                int8=(weight_dtype == trt.int8),
                                precision_constraints='obey',
                                builder_optimization_level=4))
        assert build_engine is not None

        with TrtRunner(build_engine) as runner:
            feed_dict = {
                'input_hidden_states': input_data,
            }
            if finished is not None:
                feed_dict['input_finished'] = finished
            outputs = runner.infer(feed_dict=feed_dict)

        return outputs

    def referenceImpl(self, inputs, k, actfn, weight_dtype, norm_mode):
        # Always run the ref implementation at full precision.
        # TODO: is this a good choice?
        inputs = inputs.cuda().float()
        inputs_merged = inputs.view(-1, inputs.shape[-1])
        routing = torch.matmul(inputs_merged, self.router_weights.T.float())
        assert routing.shape == (inputs_merged.shape[0],
                                 self.router_weights.shape[0])
        router_probs = torch.softmax(routing, 1, dtype=inputs.dtype)
        assert routing.shape == router_probs.shape
        topk = torch.topk(router_probs, k)
        assert topk.indices.shape == (router_probs.shape[0], k)
        results = torch.zeros_like(inputs_merged)
        for i, (scales, experts) in enumerate(zip(topk.values, topk.indices)):
            if norm_mode == MOEExpertScaleNormalizationMode.RENORMALIZE:
                scales /= sum(scales)
            input = inputs_merged[i, :]
            for scale, expert in zip(scales, experts):
                fc1_qd = quant_dequant(self.fc1_weights[expert], weight_dtype)
                if is_gated_activation(actfn):
                    fc1 = gated_matmul(input, fc1_qd.float(),
                                       self.fc1_bias[expert].float(), actfn)
                else:
                    fc1 = torch.matmul(
                        input,
                        fc1_qd.T.float()) + self.fc1_bias[expert].float()
                    fc1 = doact(fc1, actfn)
                fc2_qd = quant_dequant(self.fc2_weights[expert], weight_dtype)
                final = torch.matmul(
                    fc1, fc2_qd.T.float()) + self.fc2_bias[expert].float()
                assert final.shape == (inputs.shape[-1], )
                results[i] += scale * final

        return results.view(*inputs.shape)


if __name__ == "__main__":
    unittest.main()