# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import sys
import tempfile
import unittest
from pathlib import Path

import numpy as np
import pytest
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.generation import InferenceParams
from mamba_ssm.utils.hf import load_config_hf
from parameterized import parameterized

import tensorrt_llm
from tensorrt_llm import Builder
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.layers.ssm import MambaParameters
from tensorrt_llm.network import net_guard

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from examples.mamba.convert_checkpoint import (convert_from_hf_checkpoint,
                                               convert_hf_mamba)

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root
from utils.util import skip_bf16_pre_ampere, unittest_name_func


class TestMamba(unittest.TestCase):

    def _gen_tensorrt_llm_mamba(self, hf_config, hf_path, hf_mamba, load_mode,
                                dtype):
        config = {
            'architecture': 'MambaLMHeadModel',
            'dtype': dtype,
            'logits_dtype': 'float32',
            'hidden_size': hf_config.d_model,
            'num_hidden_layers': hf_config.n_layer,
            'vocab_size': hf_config.vocab_size,
            'ssm_cfg': MambaParameters(**hf_config.ssm_cfg).__dict__,
            'rms_norm': hf_config.rms_norm,
            'residual_in_fp32': hf_config.residual_in_fp32,
            'pad_vocab_size_multiple': hf_config.pad_vocab_size_multiple,
            'hidden_act': 'silu',
            'num_attention_heads': 1,
        }
        config = tensorrt_llm.models.PretrainedConfig.from_dict(config)
        if load_mode == 'from_checkpoint':
            weights = convert_from_hf_checkpoint(model_dir=hf_path,
                                                 dtype=dtype)
        else:
            weights = convert_hf_mamba(hf_mamba, rank=0, dtype=dtype)
        tensorrt_llm_mamba = tensorrt_llm.models.MambaLMHeadModel(config)
        tensorrt_llm_mamba.load(weights)
        return tensorrt_llm_mamba

    def _gen_tensorrt_llm_network(self, network, hf_config, hf_path, hf_mamba,
                                  load_mode, batch_size, input_len, output_len,
                                  dtype):
        tensorrt_llm_mamba = self._gen_tensorrt_llm_mamba(
            hf_config, hf_path, hf_mamba, load_mode, dtype)
        with net_guard(network):
            network.set_named_parameters(tensorrt_llm_mamba.named_parameters())
            inputs = tensorrt_llm_mamba.prepare_inputs(batch_size,
                                                       input_len,
                                                       input_len + output_len,
                                                       use_cache=False)
            # Trace a forward pass to populate the TensorRT network
            tensorrt_llm_mamba(**inputs)
        return network

    def _gen_tensorrt_llm_engine(self, model_name, gemm_plugin,
                                 mamba_conv1d_plugin, hf_config, hf_path,
                                 hf_mamba, load_mode, batch_size, input_len,
                                 output_len, dtype, remove_padding):
        builder = Builder()
        with tempfile.TemporaryDirectory() as tmpdirname:
            builder_config = builder.create_builder_config(
                name=model_name,
                precision=dtype,
                timing_cache='model.cache',
            )
            network = builder.create_network()
            network.plugin_config.to_legacy_setting()
            network.plugin_config.remove_input_padding = remove_padding
            network.plugin_config.paged_state = False
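            # Optionally enable the GEMM and Mamba conv1d plugins; when the
            # conv1d plugin is disabled, the conv state buffers use a
            # transposed layout (see conv_state_shape in test_mamba below).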
            if gemm_plugin:
                network.plugin_config.set_gemm_plugin(dtype)
            if mamba_conv1d_plugin:
                network.plugin_config.set_mamba_conv1d_plugin(dtype)
            else:
                network.plugin_config.set_mamba_conv1d_plugin(None)
            self._gen_tensorrt_llm_network(network, hf_config, hf_path,
                                           hf_mamba, load_mode, batch_size,
                                           input_len, output_len, dtype)
            engine_buffer = builder.build_engine(network, builder_config)
            return engine_buffer

    def _gen_tensorrt_llm_runtime(self, log_level, model_name, gemm_plugin,
                                  mamba_conv1d_plugin, hf_config, hf_path,
                                  hf_mamba, load_mode, batch_size, input_len,
                                  output_len, dtype, remove_padding):
        tensorrt_llm.logger.set_level(log_level)
        mapping = tensorrt_llm.Mapping()
        engine_buffer = self._gen_tensorrt_llm_engine(
            model_name, gemm_plugin, mamba_conv1d_plugin, hf_config, hf_path,
            hf_mamba, load_mode, batch_size, input_len, output_len, dtype,
            remove_padding)
        runtime = tensorrt_llm.runtime.generation._Runtime(
            engine_buffer, mapping)
        return runtime, engine_buffer

    @parameterized.expand([
        (True, True, 'float16', False),
        (False, True, 'float16', False),
        (True, True, 'bfloat16', False),
        (False, True, 'bfloat16', False),
        (True, False, 'float16', False),
        (False, False, 'float16', False),
        (True, False, 'bfloat16', False),
        (False, False, 'bfloat16', False),
        (True, True, 'float16', True),
        (False, True, 'float16', True),
        (True, True, 'bfloat16', True),
        (False, True, 'bfloat16', True),
    ], name_func=unittest_name_func)
    def test_mamba(self, gemm_plugin, mamba_conv1d_plugin, dtype,
                   remove_padding):
        # Skip tests that are not supported on pre-Ampere architectures
        skip_bf16_pre_ampere(dtype)

        RANDOM_SEEDS = [1, 4, 5, 8]
        seed_idx = random.randint(0, len(RANDOM_SEEDS) - 1)
        torch.manual_seed(RANDOM_SEEDS[seed_idx])

        model_name = 'mamba'
        log_level = 'error'
        batch_size = 4
        input_len = 16
        output_len = 2
        load_mode = 'from_model'
        hf_path = ''
        d_model = 128
        hf_config = MambaConfig(d_model=d_model, n_layer=2, vocab_size=128)

        # get hf mamba
        hf_mamba = MambaLMHeadModel(hf_config,
                                    device='cuda',
                                    dtype=str_dtype_to_torch(dtype))

        # get tensorrt llm mamba runtime
        runtime, _ = self._gen_tensorrt_llm_runtime(
            log_level, model_name, gemm_plugin, mamba_conv1d_plugin, hf_config,
            hf_path, hf_mamba, load_mode, batch_size, input_len, output_len,
            dtype, remove_padding)

        # prepare buffers
        mamba_d_inner = hf_mamba.backbone.layers[0].mixer.d_inner
        mamba_d_conv = hf_mamba.backbone.layers[0].mixer.d_conv
        mamba_d_state = hf_mamba.backbone.layers[0].mixer.d_state
        if mamba_conv1d_plugin:
            conv_state_shape = (
                batch_size,
                mamba_d_conv - 1,
                mamba_d_inner,
            )
        else:
            conv_state_shape = (
                batch_size,
                mamba_d_inner,
                mamba_d_conv - 1,
            )
        ssm_state_shape = (
            batch_size,
            mamba_d_state,
            mamba_d_inner,
        )
        present_conv_states = []
        present_conv_states_1 = []
        present_ssm_states = []
        for _ in range(hf_config.n_layer):
            present_conv_states.append(
                torch.zeros(conv_state_shape,
                            dtype=str_dtype_to_torch(dtype),
                            device='cuda'))
            present_conv_states_1.append(
                torch.empty(conv_state_shape,
                            dtype=str_dtype_to_torch(dtype),
                            device='cuda'))
            present_ssm_states.append(
                torch.empty(ssm_state_shape,
                            dtype=str_dtype_to_torch(dtype),
                            device='cuda'))

        # compare context
        if remove_padding:
            ctx_last_token_ids = torch.randint(1,
                                               input_len + 1, (batch_size, ),
                                               dtype=torch.int32)
            host_ctx_lengths = ctx_last_token_ids.detach().clone().cpu()
            ctx_last_token_ids = torch.cumsum(ctx_last_token_ids,
                                              dim=0,
                                              dtype=torch.int32).to('cuda')
            total_num_tokens = ctx_last_token_ids[batch_size - 1]
        else:
            ctx_last_token_ids = input_len * torch.ones(
                (batch_size, ), dtype=torch.int32, device='cuda')
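            # With padding kept, every sequence contributes exactly input_len
            # tokens, so context lengths are uniform across the batch.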
            total_num_tokens = batch_size * input_len
            host_ctx_lengths = input_len * torch.ones((batch_size, ),
                                                      dtype=torch.int32)
        ctx_ids = torch.randint(100, (total_num_tokens, )).int().cuda()
        if not remove_padding:
            ctx_ids = ctx_ids.view(-1, input_len)
        ctx_host_request_types = torch.tensor([0] * batch_size,
                                              dtype=torch.int32)
        step1_id = torch.randint(100, (batch_size, )).int().cuda()
        if not remove_padding:
            step1_id = step1_id.view(-1, 1)
        with torch.no_grad():
            if remove_padding:
                # logits have vocab_size columns (equal to d_model in this test)
                ref = torch.empty(batch_size, hf_config.vocab_size)
                gen_ref = torch.empty(batch_size, hf_config.vocab_size)
                for i in range(batch_size):
                    start_id = 0 if i == 0 else ctx_last_token_ids[i - 1]
                    end_id = ctx_last_token_ids[i]
                    infer_params = InferenceParams(max_seqlen=end_id -
                                                   start_id + output_len,
                                                   max_batch_size=1)
                    part_ctx_ids = torch.unsqueeze(ctx_ids[start_id:end_id],
                                                   dim=0)
                    part_hf_outputs = hf_mamba.forward(
                        part_ctx_ids, inference_params=infer_params)
                    infer_params.seqlen_offset += end_id - start_id
                    torch.cuda.synchronize()
                    ref[i][:] = part_hf_outputs.logits[0, -1, :]
                    part_step1_id = step1_id[i].view(1, 1)
                    part_hf_gen_outputs = hf_mamba.forward(
                        part_step1_id, inference_params=infer_params)
                    torch.cuda.synchronize()
                    gen_ref[i][:] = part_hf_gen_outputs.logits[0, -1, :]
                    infer_params.seqlen_offset += 1
            else:
                infer_params = InferenceParams(max_seqlen=input_len +
                                               output_len,
                                               max_batch_size=batch_size)
                hf_outputs = hf_mamba.forward(ctx_ids,
                                              inference_params=infer_params)
                infer_params.seqlen_offset += ctx_ids.shape[1]
                ref = hf_outputs.logits[:, -1, :]
            torch.cuda.synchronize()

        ctx_buffer = {
            'input_ids': ctx_ids,
            'last_token_ids': ctx_last_token_ids,
            'host_request_types': ctx_host_request_types,
            'host_context_lengths': host_ctx_lengths,
        }
        for idx in range(hf_config.n_layer):
            ctx_buffer[f'past_conv_state_{idx}'] = present_conv_states[idx]
            ctx_buffer[f'present_conv_state_{idx}'] = present_conv_states_1[idx]
            ctx_buffer[f'past_ssm_state_{idx}'] = present_ssm_states[idx]
            ctx_buffer[f'present_ssm_state_{idx}'] = present_ssm_states[idx]
        ctx_shape = {k: v.shape for k, v in ctx_buffer.items()}

        context = runtime.ctx_context
        runtime._set_shape(context, ctx_shape)
        runtime._set_buffer(context, ctx_buffer)
        runtime._run(context)
        torch.cuda.synchronize()
        res = ctx_buffer['logits']
        np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(),
                                   res.to(torch.float32).cpu().numpy(),
                                   atol=1e-2)

        # compare generation
        gen_last_token_ids = torch.ones((batch_size, ),
                                        dtype=torch.int32,
                                        device='cuda')
        if remove_padding:
            gen_last_token_ids = torch.cumsum(gen_last_token_ids,
                                              dim=0,
                                              dtype=torch.int32).to('cuda')
        gen_host_request_types = torch.tensor([1] * batch_size,
                                              dtype=torch.int32)
        if not remove_padding:
            with torch.no_grad():
                hf_outputs = hf_mamba.forward(step1_id,
                                              inference_params=infer_params)
                infer_params.seqlen_offset += step1_id.shape[1]
                torch.cuda.synchronize()
            gen_ref = hf_outputs.logits[:, -1, :]

        step1_buffer = {
            'input_ids': step1_id,
            'last_token_ids': gen_last_token_ids,
            'host_request_types': gen_host_request_types,
            'host_context_lengths': host_ctx_lengths,
        }
        for idx in range(hf_config.n_layer):
            step1_buffer[f'past_conv_state_{idx}'] = present_conv_states_1[idx]
            step1_buffer[f'present_conv_state_{idx}'] = present_conv_states[idx]
            step1_buffer[f'past_ssm_state_{idx}'] = present_ssm_states[idx]
            step1_buffer[f'present_ssm_state_{idx}'] = present_ssm_states[idx]
        step1_shape = {k: v.shape for k, v in step1_buffer.items()}

        context = runtime.context_1
        runtime._set_shape(context, step1_shape)
        runtime._set_buffer(context, step1_buffer)
        runtime._run(context)
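        # Wait for the engine run to finish before reading the generation-step
        # logits back from the shared buffers.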
        torch.cuda.synchronize()
        res = step1_buffer['logits']
        np.testing.assert_allclose(gen_ref.to(torch.float32).cpu().numpy(),
                                   res.to(torch.float32).cpu().numpy(),
                                   atol=1e-2)

    @parameterized.expand([
        ('mamba-130m', 'from_checkpoint'),
        ('mamba-130m', 'from_model'),
    ], name_func=unittest_name_func)
    def test_loaders(self, path, load_mode):
        model_root = llm_models_root()
        if model_root is None:
            pytest.skip('Skipping since real weights are unavailable.')
        hf_path = Path(model_root, path)
        if not hf_path.exists():
            pytest.skip(f'Skipping since the path {hf_path} does not exist.')
        dtype = 'float16'

        # get hf mamba
        hf_mamba = MambaLMHeadModel.from_pretrained(
            hf_path, device='cpu', dtype=str_dtype_to_torch(dtype))

        # get tensorrt llm mamba
        config_data = load_config_hf(hf_path)
        hf_config = MambaConfig(**config_data)
        tensorrt_llm_mamba = self._gen_tensorrt_llm_mamba(
            hf_config, hf_path, hf_mamba, load_mode, dtype)

        def has_bias(torch_layer):
            return hasattr(torch_layer, 'bias') and torch_layer.bias is not None

        # token embedding
        np.testing.assert_allclose(
            tensorrt_llm_mamba.backbone.vocab_embedding.weight.raw_value,
            hf_mamba.backbone.embedding.weight.cpu().detach(),
            atol=1e-3)
        # output
        np.testing.assert_allclose(tensorrt_llm_mamba.lm_head.weight.raw_value,
                                   hf_mamba.lm_head.weight.cpu().detach(),
                                   atol=1e-3)
        # norm
        np.testing.assert_allclose(
            tensorrt_llm_mamba.backbone.norm_f.weight.raw_value,
            hf_mamba.backbone.norm_f.weight.cpu().detach(),
            atol=1e-3)
        if has_bias(hf_mamba.backbone.norm_f):
            np.testing.assert_allclose(
                tensorrt_llm_mamba.backbone.norm_f.bias.raw_value,
                hf_mamba.backbone.norm_f.bias.cpu().detach(),
                atol=1e-3)

        # Checking all of the layers takes too much time; just check one
        # random layer
        l = np.random.randint(0, tensorrt_llm_mamba.config.num_hidden_layers)
        print(f"Checking Layer-{l} weights ...", flush=True)
        # layer{l}.input_layernorm
        np.testing.assert_allclose(
            tensorrt_llm_mamba.backbone.layers[l].input_layernorm.weight.
            raw_value,
            hf_mamba.backbone.layers[l].norm.weight.cpu().detach(),
            atol=1e-3)
        if has_bias(hf_mamba.backbone.layers[l].norm):
            np.testing.assert_allclose(
                tensorrt_llm_mamba.backbone.layers[l].input_layernorm.bias.
                raw_value,
                hf_mamba.backbone.layers[l].norm.bias.cpu().detach(),
                atol=1e-3)
        # layer{l}.ssm.A
        A_hf = -torch.exp(hf_mamba.backbone.layers[l].mixer.A_log.float())
        A_hf_permute = A_hf.cpu().detach().permute([1, 0]).contiguous()
        np.testing.assert_allclose(
            tensorrt_llm_mamba.backbone.layers[l].ssm.A.raw_value,
            A_hf_permute,
            atol=1e-3)
        # layer{l}.ssm.D
        np.testing.assert_allclose(
            tensorrt_llm_mamba.backbone.layers[l].ssm.D.raw_value,
            hf_mamba.backbone.layers[l].mixer.D.float().cpu().detach(),
            atol=1e-3)
        # layer{l}.ssm.dt_bias
        np.testing.assert_allclose(
            tensorrt_llm_mamba.backbone.layers[l].ssm.dt_bias.raw_value,
            hf_mamba.backbone.layers[l].mixer.dt_proj.bias.cpu().to(
                torch.float32).detach(),
            atol=1e-3)
        # layer{l}.ssm.in_proj
        d_inner = tensorrt_llm_mamba.backbone.layers[
            l].ssm.in_proj_x.weight.raw_value.shape[0]
        in_proj_x_hf = hf_mamba.backbone.layers[l].mixer.in_proj.weight[
            0:d_inner, ]
        in_proj_z_hf = hf_mamba.backbone.layers[l].mixer.in_proj.weight[
            d_inner:, ]
        np.testing.assert_allclose(tensorrt_llm_mamba.backbone.layers[l].ssm.
                                   in_proj_x.weight.raw_value,
                                   in_proj_x_hf.cpu().detach(),
                                   atol=1e-3)
        np.testing.assert_allclose(tensorrt_llm_mamba.backbone.layers[l].ssm.
                                   in_proj_z.weight.raw_value,
                                   in_proj_z_hf.cpu().detach(),
                                   atol=1e-3)
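        # The same x/z split applies to the optional in_proj bias, since
        # TensorRT-LLM stores in_proj as separate x and z projections.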
        if has_bias(hf_mamba.backbone.layers[l].mixer.in_proj):
            in_proj_bias_x_hf = hf_mamba.backbone.layers[l].mixer.in_proj.bias[
                0:d_inner]
            in_proj_bias_z_hf = hf_mamba.backbone.layers[l].mixer.in_proj.bias[
                d_inner:]
            np.testing.assert_allclose(tensorrt_llm_mamba.backbone.layers[l].
                                       ssm.in_proj_x.bias.raw_value,
                                       in_proj_bias_x_hf.cpu().detach(),
                                       atol=1e-3)
            np.testing.assert_allclose(tensorrt_llm_mamba.backbone.layers[l].
                                       ssm.in_proj_z.bias.raw_value,
                                       in_proj_bias_z_hf.cpu().detach(),
                                       atol=1e-3)
        # layer{l}.ssm.conv1d
        np.testing.assert_allclose(
            tensorrt_llm_mamba.backbone.layers[l].ssm.conv1d.weight.raw_value,
            hf_mamba.backbone.layers[l].mixer.conv1d.weight.unsqueeze(
                3).cpu().detach(),
            atol=1e-3)
        if has_bias(hf_mamba.backbone.layers[l].mixer.conv1d):
            np.testing.assert_allclose(
                tensorrt_llm_mamba.backbone.layers[l].ssm.conv1d.bias.raw_value,
                hf_mamba.backbone.layers[l].mixer.conv1d.bias.cpu().detach(),
                atol=1e-3)
        # layer{l}.ssm.x_proj
        np.testing.assert_allclose(
            tensorrt_llm_mamba.backbone.layers[l].ssm.x_proj.weight.raw_value,
            hf_mamba.backbone.layers[l].mixer.x_proj.weight.cpu().detach(),
            atol=1e-3)
        # layer{l}.ssm.dt_proj
        np.testing.assert_allclose(
            tensorrt_llm_mamba.backbone.layers[l].ssm.dt_proj.weight.raw_value,
            hf_mamba.backbone.layers[l].mixer.dt_proj.weight.cpu().detach(),
            atol=1e-3)
        # layer{l}.ssm.out_proj
        np.testing.assert_allclose(
            tensorrt_llm_mamba.backbone.layers[l].ssm.out_proj.weight.raw_value,
            hf_mamba.backbone.layers[l].mixer.out_proj.weight.cpu().detach(),
            atol=1e-3)
        if has_bias(hf_mamba.backbone.layers[l].mixer.out_proj):
            np.testing.assert_allclose(
                tensorrt_llm_mamba.backbone.layers[l].ssm.out_proj.bias.
                raw_value,
                hf_mamba.backbone.layers[l].mixer.out_proj.bias.cpu().detach(),
                atol=1e-3)
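

# Standard unittest entry point so this module can also be run directly with
# `python <path-to-this-test>`; pytest-based discovery does not require it.
if __name__ == '__main__':
    unittest.main()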