# mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14)
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from ..attention_backend import AttentionMetadata
from ..distributed import ParallelConfig, TensorParallelMode
from ..model_config import ModelConfig
from .linear import Linear


# Cache tensors covering all Mamba layers plus the per-request slot indices
# selected for the current batch (see MambaCacheManager.get_params).
@dataclass
class MambaCacheParams:
    conv_states: torch.Tensor = torch.Tensor()
    ssm_states: torch.Tensor = torch.Tensor()
    indices: torch.Tensor = torch.Tensor()


class MambaCacheManager:

    def __init__(
        self,
        *,
        d_model: int,
        d_state: int,
        d_conv: int,
        expand: int,
        n_groups: int,
        head_dim: int,
        num_mamba_layers: int,
        max_batch_size: int,
        dtype: Optional[torch.dtype] = None,
        config: Optional[ModelConfig] = None,
    ):

        config = config or ModelConfig()
        tp_size = config.mapping.tp_size

        d_inner = d_model * expand
        conv_dim = d_inner + 2 * n_groups * d_state
        nheads = d_inner // head_dim

        assert nheads % tp_size == 0, "nheads must be divisible by tp_size"
        assert conv_dim % tp_size == 0, "conv_dim must be divisible by tp_size"

        conv_dim = conv_dim // tp_size
        nheads = nheads // tp_size

        device = torch.device("cuda")

        # rolling window of the last (d_conv - 1) convolution inputs,
        # kept per layer and per cache slot
        self.conv_states = torch.empty(
            size=[num_mamba_layers, max_batch_size, d_conv - 1, conv_dim],
            dtype=dtype,
            device=device,
        )

        # recurrent SSM state per layer, cache slot and head
        self.ssm_states = torch.empty(
            size=[
                num_mamba_layers,
                max_batch_size,
                nheads,
                d_state,
                head_dim,
            ],
            dtype=dtype,
            device=device,
        )

    def get_params(self, attn_metadata: AttentionMetadata) -> MambaCacheParams:

        # request_ids is set to None when warming up the engine;
        # we map this warmup request to slot zero
        request_ids = ([0] if attn_metadata.request_ids is None else
                       attn_metadata.request_ids)

        # this implements the simplest possible mapping, indices = request_ids
        indices = torch.as_tensor(request_ids, dtype=torch.int32)

        # return cache params
        return MambaCacheParams(self.conv_states, self.ssm_states, indices)
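

# Hedged usage sketch (added for illustration; not part of the upstream file).
# It shows how a runner might allocate a MambaCacheManager and fetch the
# per-step cache view. All hyperparameter values below are assumptions for the
# example, not values taken from any particular checkpoint, and a CUDA device
# is required because the cache tensors are allocated on "cuda".
def _example_cache_roundtrip(attn_metadata: AttentionMetadata) -> MambaCacheParams:
    manager = MambaCacheManager(
        d_model=2048,         # assumed hidden size
        d_state=128,          # assumed SSM state size
        d_conv=4,             # conv window of 4 -> 3 cached positions per slot
        expand=2,
        n_groups=8,
        head_dim=64,
        num_mamba_layers=32,  # assumed number of Mamba layers
        max_batch_size=8,
        dtype=torch.float16,
    )
    # The returned params alias the manager's cache tensors; `indices` simply
    # mirrors attn_metadata.request_ids (the identity slot mapping above).
    return manager.get_params(attn_metadata)

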
class RMSNormGroup(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_groups: int,
        eps: float = 1e-5,
        dtype: Optional[torch.dtype] = None,
        config: Optional[ModelConfig] = None,
    ):
        super().__init__()

        config = config or ModelConfig()
        tp_size = config.mapping.tp_size
        tp_rank = config.mapping.tp_rank

        assert (hidden_size %
                num_groups == 0), "hidden_size must be divisible by num_groups"
        assert (hidden_size %
                tp_size == 0), "hidden_size must be divisible by tp_size"
        assert (num_groups %
                tp_size == 0), "num_groups must be divisible by tp_size"

        tp_hidden_size = hidden_size // tp_size
        tp_ngroups = num_groups // tp_size

        self.num_groups = tp_ngroups
        self.group_size = tp_hidden_size // tp_ngroups

        self.weight = nn.Parameter(torch.empty(tp_hidden_size, dtype=dtype))
        self.variance_epsilon = eps
        self.tp_size = tp_size
        self.tp_rank = tp_rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        input_dtype = x.dtype
        y = x.to(torch.float32)

        # Reshape to [*, num_groups, group_size]
        orig_shape = y.shape
        y = y.view(*y.shape[:-1], self.num_groups, self.group_size)

        # Compute variance over the group_size dimension
        variance = y.pow(2).mean(-1, keepdim=True)

        # Normalize within each group
        y = y * torch.rsqrt(variance + self.variance_epsilon)

        # Reshape back to original shape
        y = y.view(*orig_shape)

        # Apply the scaling weight
        y = self.weight * y.to(input_dtype)

        return y
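

# Hedged reference sketch (added for illustration; not part of the upstream
# file): a plain-PyTorch equivalent of RMSNormGroup.forward for a single rank,
# handy for eager-mode unit tests. `weight` is assumed to have shape
# [hidden_size] and `x` shape [..., hidden_size], matching the module above
# with tp_size == 1.
def _grouped_rms_norm_reference(x: torch.Tensor,
                                weight: torch.Tensor,
                                num_groups: int,
                                eps: float = 1e-5) -> torch.Tensor:
    group_size = x.shape[-1] // num_groups
    # normalize in fp32 within each group, then rescale in the input dtype
    y = x.to(torch.float32).view(*x.shape[:-1], num_groups, group_size)
    y = y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
    return weight * y.view(*x.shape).to(x.dtype)

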
class Mamba2(nn.Module):

    def __init__(
        self,
        *,
        d_model: int,
        d_state: int,
        d_conv: int,
        expand: int,
        n_groups: int,
        head_dim: int,
        chunk_size: int,
        layer_idx: int,
        bias: bool = False,
        conv_bias: bool = True,
        delta_rank: int = 0,
        delta_softplus: bool = True,
        remove_padding: bool = True,
        apply_silu: bool = True,
        rms_norm_eps: float = 1e-5,
        dtype: Optional[torch.dtype] = None,
        config: Optional[ModelConfig] = None,
    ):
        super().__init__()

        config = config or ModelConfig()
        tp_rank = config.mapping.tp_rank
        tp_size = config.mapping.tp_size
        gpus_per_node = config.mapping.gpus_per_node

        d_inner = d_model * expand
        nheads = d_inner // head_dim
        d_in_proj = 2 * d_inner + 2 * n_groups * d_state + nheads
        conv_dim = d_inner + 2 * n_groups * d_state

        # TP
        self.tp_conv_dim = conv_dim // tp_size
        self.tp_d_inner = d_inner // tp_size
        self.tp_nheads = nheads // tp_size
        self.tp_ngroups = n_groups // tp_size

        self.layer_idx = layer_idx
        self.d_conv = d_conv
        self.d_state = d_state
        self.chunk_size = chunk_size
        self.delta_rank = delta_rank
        self.delta_softplus = delta_softplus
        self.remove_padding = remove_padding
        self.is_mamba2 = True
        self.apply_silu = apply_silu

        # tp
        self.tp_size = tp_size
        self.tp_rank = tp_rank

        # paged state parameters
        self.slot_mapping = None
        self.is_paged_state = False

        # in_proj
        self.in_proj = Linear(
            d_model,
            d_in_proj,
            bias=bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                gpus_per_node=gpus_per_node),
            quant_config=config.get_quant_config(),
        )

        # conv1d: reuse Linear to store the weights, since it already supports TP > 1
        self.conv1d = Linear(
            d_conv,
            conv_dim,
            bias=conv_bias,
            dtype=dtype,
            parallel_config=ParallelConfig(
                tensor_parallel_rank=tp_rank,
                tensor_parallel_size=tp_size,
                tensor_parallel_mode=TensorParallelMode.COLUMN,
                gpus_per_node=gpus_per_node),
            quant_config=config.get_quant_config(),
            skip_create_weights=config.skip_create_weights,
        )

        # A
        self.A = nn.Parameter(
            torch.empty(self.tp_nheads,
                        dtype=torch.float32,
                        requires_grad=False))

        # D
        self.D = nn.Parameter(
            torch.empty(self.tp_nheads,
                        dtype=torch.float32,
                        requires_grad=False))

        # dt_bias
        self.dt_bias = nn.Parameter(
            torch.empty(self.tp_nheads,
                        dtype=torch.float32,
                        requires_grad=False))

        # norm
        self.norm = RMSNormGroup(d_inner,
                                 n_groups,
                                 eps=rms_norm_eps,
                                 dtype=dtype,
                                 config=config)

        # out_proj
        self.out_proj = Linear(d_inner,
                               d_model,
                               bias=bias,
                               dtype=dtype,
                               parallel_config=ParallelConfig(
                                   tensor_parallel_rank=tp_rank,
                                   tensor_parallel_size=tp_size,
                                   tensor_parallel_mode=TensorParallelMode.ROW,
                                   gpus_per_node=gpus_per_node),
                               quant_config=config.get_quant_config())

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
    ) -> torch.Tensor:

        # calculate split size
        num_contexts = attn_metadata.num_contexts
        num_generations = attn_metadata.seq_lens.shape[0] - num_contexts
        sum_seq = torch.cumsum(attn_metadata.seq_lens, dim=0)
        split_ctx = sum_seq[num_contexts - 1] if num_contexts > 0 else 0
        split_gen = sum_seq[-1] - split_ctx
        split_size = [split_ctx, split_gen]

        # make the split
        split_indices = torch.split(mamba_cache_params.indices,
                                    [num_contexts, num_generations])
        split_seq_lens = torch.split(attn_metadata.seq_lens,
                                     [num_contexts, num_generations])
        split_hidden_states = torch.split(hidden_states, split_size)
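
        # Worked example (illustrative comment added to this mirror, with
        # assumed numbers): for seq_lens = [5, 7, 1, 1] and num_contexts = 2,
        # sum_seq is [5, 12, 13, 14], so split_ctx = 12 context tokens and
        # split_gen = 2 generation tokens. Assuming token-major
        # [num_tokens, d_model] inputs, split_hidden_states then holds a
        # [12, d_model] chunk followed by a [2, d_model] chunk, while
        # split_indices and split_seq_lens are split per request as [2, 2].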
        # a batch can have either:
        # * only context requests
        # * only generation requests
        # * both context and generation requests
        # req_type = 0 -> context
        # req_type = 1 -> generation
        batch = None
        # both context and generation requests
        if num_contexts > 0 and num_generations > 0:
            batch = [0, 1]
        # only context requests
        elif num_contexts > 0:
            batch = [0]
        # only generation requests
        elif num_generations > 0:
            batch = [1]

        out = []
        for req_type in batch:

            # read conv and ssm states
            split_batch = split_indices[req_type].tolist()
            conv_states = mamba_cache_params.conv_states[
                self.layer_idx][split_batch]
            ssm_states = mamba_cache_params.ssm_states[
                self.layer_idx][split_batch]

            # host request types
            host_request_types = (torch.zeros_like(split_seq_lens[req_type])
                                  if req_type == 0 else torch.ones_like(
                                      split_seq_lens[req_type]))

            # last token ids
            last_token_ids = torch.cumsum(split_seq_lens[req_type],
                                          dim=0,
                                          dtype=torch.int32).cuda()

            # in_proj
            zxbcdt = self.in_proj(split_hidden_states[req_type])

            conv_weight = self.conv1d.weight.unsqueeze(1).permute(
                1, 2, 0).contiguous()
            xbc, conv_states_out = torch.ops.trtllm.mamba_conv1d(
                zxbcdt,
                conv_weight,
                self.conv1d.bias,
                conv_states,
                host_request_types,
                last_token_ids,
                split_seq_lens[req_type],
                self.slot_mapping,
                self.tp_conv_dim,
                self.d_conv,
                self.tp_d_inner,  # pre_stride
                self.tp_nheads,  # post_stride
                self.remove_padding,
                self.apply_silu,
                self.is_paged_state,
            )

            # selective scan
            y, ssm_states_out = torch.ops.trtllm.selective_scan(
                xbc,
                ssm_states,
                zxbcdt,
                self.dt_bias,
                self.A,
                xbc,
                self.D,
                host_request_types,
                last_token_ids,
                zxbcdt,
                split_seq_lens[req_type],
                self.slot_mapping,
                self.tp_d_inner,
                self.d_state,
                self.tp_nheads,
                self.tp_ngroups,
                self.chunk_size,
                self.delta_rank,
                self.delta_softplus,
                self.remove_padding,
                self.is_mamba2,
                self.is_paged_state,
            )

            # group norm
            y = self.norm(y)

            # out_proj
            y = self.out_proj(y)

            # append output
            out.append(y)

            # update conv and ssm states
            for i, idx in enumerate(split_batch):
                mamba_cache_params.conv_states[self.layer_idx][idx].copy_(
                    conv_states_out[i])
                mamba_cache_params.ssm_states[self.layer_idx][idx].copy_(
                    ssm_states_out[i])

        out = torch.cat(out, dim=0)
        return out
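

# Hedged illustration (added to this mirror; not part of the upstream file):
# how the fused in_proj output could be decomposed on one rank. The ordering
# [z | x, B, C | dt] is an assumption inferred from how d_in_proj is assembled
# in Mamba2.__init__ and from the pre_stride/post_stride arguments handed to
# mamba_conv1d above; the fused kernels consume zxbcdt directly and never
# materialize this split in Python.
def _example_split_zxbcdt(layer: Mamba2, zxbcdt: torch.Tensor):
    z, xbc, dt = torch.split(
        zxbcdt,
        [layer.tp_d_inner, layer.tp_conv_dim, layer.tp_nheads],
        dim=-1,
    )
    return z, xbc, dt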