mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
93 lines
3.1 KiB
Python
93 lines
3.1 KiB
Python
# Copyright 2024 DeepMind Technologies Limited.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""Utils for positional embeddings (including RoPE).
|
|
"""
|
|
|
|
import jax
|
|
import jax.numpy as jnp
|
|
|
|
_MAX_WAVELENGTH = 10_000
|
|
|
|
|
|
def add_positional_embedding(
|
|
input_embedding: jax.Array,
|
|
position: int,
|
|
max_wavelength: int = _MAX_WAVELENGTH,
|
|
) -> jax.Array:
|
|
"""Adds positional embeddings to input embeddings."""
|
|
embed_dim = input_embedding.shape[-1]
|
|
num_timescales = embed_dim // 2
|
|
log_timescale_increment = jnp.log(float(max_wavelength)) / jnp.maximum(
|
|
jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)
|
|
inv_timescales = jnp.exp(
|
|
jnp.arange(num_timescales, dtype=jnp.float32) *
|
|
-log_timescale_increment)
|
|
scaled_time = position * inv_timescales
|
|
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)])
|
|
signal = jnp.pad(signal, [[0, jnp.mod(embed_dim, 2)]])
|
|
position_embedding = signal.astype(jnp.float32)
|
|
|
|
return input_embedding + position_embedding
|
|
|
|
|
|
def _rotary_embed(
|
|
inputs: jax.Array, # [B, 1, H, D]
|
|
position: jax.Array, # [B,]
|
|
head_dim: int,
|
|
max_wavelength: int = _MAX_WAVELENGTH,
|
|
) -> jax.Array:
|
|
"""Helper for RoPE."""
|
|
fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
|
|
timescale = max_wavelength**fraction
|
|
timescale = timescale[jnp.newaxis, jnp.newaxis, jnp.newaxis, :]
|
|
|
|
sinusoid_inp = position[:, jnp.newaxis, jnp.newaxis,
|
|
jnp.newaxis] / timescale
|
|
sin = jnp.sin(sinusoid_inp)
|
|
cos = jnp.cos(sinusoid_inp)
|
|
|
|
first_half, second_half = jnp.split(inputs, 2, axis=-1)
|
|
first_part = first_half * cos - second_half * sin
|
|
second_part = second_half * cos + first_half * sin
|
|
|
|
return jnp.concatenate([first_part, second_part], axis=-1)
|
|
|
|
|
|
def apply_rope(
|
|
inputs: jax.Array,
|
|
position: int,
|
|
head_dim: int,
|
|
max_wavelength: int = _MAX_WAVELENGTH,
|
|
) -> jax.Array:
|
|
"""Applies RoPE."""
|
|
batch_size, seq_length = inputs.shape[0:2]
|
|
|
|
position = jnp.broadcast_to(position, [batch_size])[:, jnp.newaxis]
|
|
prefix_position = jnp.arange(seq_length, dtype=jnp.int32)
|
|
prefix_position = (position - jnp.flip(prefix_position)[jnp.newaxis, :]
|
|
) # [B, seq_len]
|
|
prefix_position = jnp.where(prefix_position < 0,
|
|
jnp.zeros_like(prefix_position),
|
|
prefix_position).reshape((batch_size, ))
|
|
|
|
output = _rotary_embed(
|
|
inputs,
|
|
position=prefix_position,
|
|
head_dim=head_dim,
|
|
max_wavelength=max_wavelength,
|
|
)
|
|
|
|
return output
|