"""Gemma transformer."""
|
|
|
|
import dataclasses
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from flax import linen as nn
|
|
|
|
from . import layers, modules
|
|
from . import params as params_lib
|
|
|
|
Cache = dict[str, modules.LayerCache]
|
|
|
|
|
|
@dataclasses.dataclass
class TransformerConfig:
  """Configuration for the Gemma transformer."""

  num_layers: int
  num_embed: int
  embed_dim: int
  hidden_dim: int
  num_heads: int
  head_dim: int
  num_kv_heads: int

  @classmethod
  def from_params(cls, params: params_lib.Params,
                  num_embed: int) -> 'TransformerConfig':
    """Creates a TransformerConfig from loaded parameters."""
    # The layer count is one more than the highest 'layer_{i}' index present.
    num_layers = (max([
        int(k.split('_')[1])
        for k in params['transformer'].keys() if 'layer_' in k
    ]) + 1)
    # Remaining sizes are read from the parameter shapes of layer 0.
    hidden_dim, embed_dim = (
        params['transformer']['layer_0']['mlp']['linear'].shape)
    num_heads, head_dim, _ = (
        params['transformer']['layer_0']['attn']['attn_vec_einsum']['w'].shape)
    # A fused 'qkv_einsum' kernel means keys and values use as many heads as
    # the queries; otherwise the KV head count is read from 'kv_einsum'.
    use_qkv_einsum = 'qkv_einsum' in params['transformer']['layer_0']['attn']
    if use_qkv_einsum:
      num_kv_heads = num_heads
    else:
      num_kv_heads = params['transformer']['layer_0']['attn'][
          'kv_einsum']['w'].shape[1]
    return cls(
        num_layers=num_layers,
        num_embed=num_embed,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        head_dim=head_dim,
        num_kv_heads=num_kv_heads,
    )


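# Illustrative sketch (not part of this file): the nested parameter layout
# that `from_params` reads. Only the shapes accessed above matter; the full
# 'kv_einsum' kernel shape is an assumption here, since the code only reads
# its second dimension, and `vocab_size` is a placeholder.
#
#   params = {
#       'transformer': {
#           'layer_0': {
#               'mlp': {'linear': ...},  # shape (hidden_dim, embed_dim)
#               'attn': {
#                   'attn_vec_einsum': {'w': ...},  # (num_heads, head_dim, _)
#                   # 'qkv_einsum' for MHA checkpoints, otherwise:
#                   'kv_einsum': {'w': ...},  # (2, num_kv_heads, ...) assumed
#               },
#           },
#           'layer_1': {...},
#       },
#   }
#   config = TransformerConfig.from_params(params, num_embed=vocab_size)

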
def init_cache(config: TransformerConfig, cache_size: int,
               batch_size: int) -> Cache:
  """Initializes a new Transformer cache."""
  return {
      f'layer_{i}': modules.init_layer_cache(
          cache_size, config.num_heads, config.head_dim, batch_size)
      for i in range(config.num_layers)
  }


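# Usage sketch (not part of this file): building a decoding cache. The
# checkpoint `params`, the cache length and the batch size are placeholders.
#
#   config = TransformerConfig.from_params(params, num_embed=vocab_size)
#   cache = init_cache(config, cache_size=1024, batch_size=1)
#   # The result has one key/value cache entry per layer, keyed
#   # 'layer_0' ... f'layer_{config.num_layers - 1}', each sized by
#   # modules.init_layer_cache for 1024 cached positions.

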
class Transformer(nn.Module):
  """Gemma transformer."""

  config: TransformerConfig

  def setup(self):
    self.embedder = modules.Embedder(
        vocab_size=self.config.num_embed,
        embed_dim=self.config.embed_dim,
    )
    self.blocks = [
        modules.Block(
            name=f'layer_{i}',
            num_heads=self.config.num_heads,
            num_kv_heads=self.config.num_kv_heads,
            embed_dim=self.config.embed_dim,
            head_dim=self.config.head_dim,
            hidden_dim=self.config.hidden_dim,
        ) for i in range(self.config.num_layers)
    ]
    self.final_norm = layers.RMSNorm()

  def __call__(
      self,
      last_tokens: jax.Array,  # [B,]
      current_token_position: int,
      cache: Cache,
      attention_mask: jax.Array,  # [B, 1, L]
      time_step: int,
  ) -> tuple[jax.Array, Cache]:
    # Decodes one token per call: embed the latest tokens, run every block
    # while updating its entry in the cache, then project back to logits.
    input_emb = self.embedder.encode(last_tokens)
    x = jnp.expand_dims(input_emb, axis=1)  # adding temporal dimension

    for i, block in enumerate(self.blocks):
      layer_name = f'layer_{i}'
      cache[layer_name], x = block(
          x,
          current_token_position,
          cache[layer_name],
          attention_mask,
          time_step,
      )

    x = self.final_norm(x)
    logits = self.embedder.decode(x)

    return logits, cache
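

# End-to-end sketch (not part of this file): one decoding step through
# Flax's apply. `params`, `vocab_size`, `bos_id`, the mask dtype and the
# variable collection layout are assumptions, not taken from this module.
#
#   config = TransformerConfig.from_params(params, num_embed=vocab_size)
#   model = Transformer(config=config)
#   cache = init_cache(config, cache_size=1024, batch_size=1)
#   tokens = jnp.array([bos_id], dtype=jnp.int32)     # [B,]
#   mask = jnp.ones((1, 1, 1024), dtype=jnp.bool_)    # [B, 1, L]
#   logits, cache = model.apply(
#       {'params': params['transformer']}, tokens, 0, cache, mask, 0)
#   next_token = jnp.argmax(logits[:, -1, :], axis=-1)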