# Copyright 2024 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base layers."""

import jax
import jax.numpy as jnp
from flax import linen as nn

class Einsum(nn.Module):
  """A thin wrapper around jnp.einsum with a single learned weight tensor."""

  shape: tuple[int, ...]

  @nn.compact
  def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
    # The weight is zero-initialized here; in a full model its values are
    # presumably overwritten when a checkpoint is restored.
    w = self.param('w', nn.initializers.zeros_init(), self.shape)
    return jnp.einsum(eqn, x, w)

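# A minimal usage sketch for Einsum (illustrative only: the weight shape,
# einsum equation, and input shape below are assumptions, not taken from
# this file):
#
#   layer = Einsum(shape=(8, 16))              # weight w of shape (D, F)
#   x = jnp.ones((2, 8))                       # inputs of shape (B, D)
#   params = layer.init(jax.random.PRNGKey(0), 'BD,DF->BF', x)
#   y = layer.apply(params, 'BD,DF->BF', x)    # y has shape (B, F) == (2, 16)
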
class RMSNorm(nn.Module):
  """RMSNorm layer with a zero-centered learned scale."""

  @nn.compact
  def __call__(self, x):
    # The scale is zero-initialized; the (1 + scale) gain below makes a zero
    # scale equivalent to the identity gain.
    scale = self.param('scale', nn.initializers.zeros_init(), (x.shape[-1],))
    # Normalize by the root mean square over the feature axis; the 1e-06
    # epsilon guards against division by zero.
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
    normed_inputs = normed_inputs * (1 + scale)
    return normed_inputs
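
# A minimal usage sketch for RMSNorm (illustrative only: input shapes are
# assumptions):
#
#   norm = RMSNorm()
#   x = jnp.ones((2, 8))                       # (batch, features)
#   params = norm.init(jax.random.PRNGKey(0), x)
#   y = norm.apply(params, x)                  # same shape as x; with the
#                                              # zero-initialized scale this
#                                              # is x / rms(x) per row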