mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
feat: [AutoDeploy] update rope matcher with minor variants (Deepseek) (#3638)
* add docstring to summarize current rope support
* minor: replace call_method, adjust inserting order of cos_sin_cache calculation node
* add unit test for triton rope and ds rope
* update rope matcher to match DS RoPE, add custom op for reference, add unit test case
* cache cos[pos_idx].unsqueeze and sin[pos_idxs].unsqueeze
* minor doc update
* separate pattern matching and optimization for explicit and complex rope + minor updates
* clean rope impl in repo
* replace fused_flattened_mla_with_cache's rope impl with torch_apply_rope_with_qk_interleaving, update unit test
* minor
* separate layout infer and transpose to a new transformation
* update rope_with_explicit_freqs and rope_with_input_interleaved to expose unsqueeze_dim and support match_rope_layout, add unit tests
* solve merge conflict in transform.py, need to fix optimize_rope with cuda graph capture
* minor clean up after rebase
* fix pre-commit
* support map to bnsd layout and infer unsqueeze_dim from op
* fix cos/sin not the same across prompts in the same batch issue when mapping to flashinfer op
* fix for unit test
* fix custom op input/output node ordering issue for DeepSeek V3 rope
* clean code
* minor
* move flattening of cos_sin_cache to the graph, update flashinfer op docstring and test
* debug transform unit test failure

---------

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
This commit is contained in:
parent f5b6d453aa
commit bce281d592
@@ -10,4 +10,5 @@ from .mla import *
from .quant import *
from .rope import *
from .torch_attention import *
from .torch_rope import *
from .triton_attention import *
@@ -21,10 +21,11 @@ def apply_rope_with_input_pos_flashinfer(
Tensors of shape [batch, seq_len, n_head, head_dim] (or a 3D variant)
in half precision. Note: head_dim must be a multiple of 64.
- position_ids (torch.Tensor):
Precomputed tensor of positional indices; it is shared across calls in the graph.
Precomputed tensor of positional indices indicating idx in cos_sin_cache for each token;
Shape [batch, seq_len] or [batch * seq_len]
- cos_sin_cache (torch.Tensor):
Precomputed fused tensor created by concatenating the first half of the cosine and sine
components derived from the inv_freq.
components derived from the inv_freq. Shape [max_seq_len, head_dim]. Must be float32.
- is_neox (bool):
Flag to indicate whether to invoke the FlashInfer kernel in Neox mode.

@@ -35,14 +36,13 @@ def apply_rope_with_input_pos_flashinfer(
"""
q_shape = q.shape
k_shape = k.shape
batch_size, seq_len = q_shape[:2]

head_dim = cos_sin_cache.shape[-1]

q_flat = q.view(batch_size * seq_len, -1)
k_flat = k.view(batch_size * seq_len, -1)
position_ids = position_ids.view(-1).to(q.device)
num_nnz = position_ids.shape[0]

position_ids = position_ids.to(q.device)
q_flat = q.view(num_nnz, -1)
k_flat = k.view(num_nnz, -1)

query_rotated_flash, key_rotated_flash = flashinfer.rope.apply_rope_with_cos_sin_cache(
position_ids, q_flat, k_flat, head_dim, cos_sin_cache, is_neox=is_neox

@@ -60,4 +60,4 @@ def apply_rope_with_input_pos_flashinfer_fake(
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
return q, k
return torch.empty_like(q), torch.empty_like(k)
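For orientation, here is a minimal call sketch for the op documented above. It is not part of the diff: it assumes the op is registered as torch.ops.rope.flashinfer (the name the graph transforms later in this commit call), that the flashinfer package is installed, and that a CUDA device is available.

import torch

# Shapes/dtypes follow the docstring: half-precision q/k, head_dim a multiple of 64,
# and a float32 fused cos/sin cache of shape [max_seq_len, head_dim].
B, S, N, D = 2, 8, 4, 64
q = torch.randn(B, S, N, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, S, N, D, dtype=torch.float16, device="cuda")

# Per-token position indices, flattened to [B * S].
position_ids = torch.arange(S, device="cuda").repeat(B)

# Fused cache: first half of cos concatenated with first half of sin.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2, device="cuda").float() / D))
angles = torch.outer(torch.arange(S, device="cuda").float(), inv_freq)  # [S, D/2]
cos_sin_cache = torch.cat((angles.cos(), angles.sin()), dim=-1)         # [S, D], float32

q_rot, k_rot = torch.ops.rope.flashinfer(q, k, position_ids, cos_sin_cache, True)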
@@ -17,7 +17,6 @@ from .attention_interface import (
PrepareMetadataCallable,
SequenceInfo,
)
from .torch_attention import apply_rotary_pos_emb_ds
from .triton_attention import _flattened_context_mha, _generate_mha

Constant = Union[int, float, str, None]
@@ -74,26 +73,38 @@ def fused_flattened_mla_with_cache(
# Apply RoPE
if cos_sin_stacked.numel() > 0:
# Extract cos and sin from freqs_cis
cos = cos_sin_stacked[0, ...]
sin = cos_sin_stacked[1, ...]
cos_base = cos_sin_stacked[0, ...]
sin_base = cos_sin_stacked[1, ...]

# TODO: Use triton kernels for RoPE
# TODO: Add yarn support
for idx in range(seq_len.shape[0]):
(
q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
) = apply_rotary_pos_emb_ds(
q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
for i in range(seq_len.shape[0]):
start = seq_start[i]
length = seq_len[i]

# build position_ids
if s == 1:
idx = (input_pos[i] + length - 1).item()
pos_ids = torch.tensor(idx, device=cos_base.device)
else:
pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device)

cos = cos_base[pos_ids]  # [..., 1, head_dim]
sin = sin_base[pos_ids]
q_slice = q_pe[start : start + length]
k_slice = k_pe[start : start + length]

q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
q_slice,
k_slice,
cos,
sin,
torch.arange(input_pos[idx] + seq_len[idx])[-1]
if s == 1
else torch.arange(input_pos[idx] + seq_len[idx]),
-2,
)

q_pe[start : start + length] = q_rot
k_pe[start : start + length] = k_rot

# Create query_states, key_states
query_states = torch.cat((q_nope, q_pe), dim=-1)  # [b*s,n,d]
key_states = torch.cat((k_nope, k_pe.expand(*bs_view, num_heads, -1)), dim=-1)  # [b*s,n,d]
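The loop above builds position ids per flattened sequence: a full arange over [input_pos, input_pos + seq_len) in the prefill path, and only the last absolute position in the generate path (s == 1). A small self-contained illustration of that indexing, with hypothetical values and the decode case keyed off length == 1 for simplicity instead of the global s:

import torch

input_pos = torch.tensor([0, 7])   # cache offsets per sequence
seq_len = torch.tensor([5, 1])     # new tokens per sequence

for i in range(seq_len.shape[0]):
    length = seq_len[i]
    if length == 1:  # generate-style step: only the newest position is needed
        pos_ids = torch.tensor((input_pos[i] + length - 1).item())
    else:            # prefill-style step: positions for every new token
        pos_ids = torch.arange(input_pos[i], input_pos[i] + length)
    print(f"seq {i}: pos_ids = {pos_ids.tolist()}")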
@ -138,92 +138,6 @@ def bsnd_grouped_sdpa_fake(
|
||||
return torch.empty_like(query.contiguous())
|
||||
|
||||
|
||||
# Function to apply rotary positional embeddings (RoPE)
|
||||
def apply_rotary_pos_emb(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
seq_len: int,
|
||||
head_dim: int,
|
||||
rope_theta: Optional[float] = None,
|
||||
rope_scale: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Apply rotary positional embeddings to query and key tensors.
|
||||
Args:
|
||||
q: Query tensor of shape [batch, n_heads, seq_len, head_dim]
|
||||
k: Key tensor of shape [batch, n_kv_heads, seq_len, head_dim]
|
||||
seq_len: Sequence length
|
||||
head_dim: Dimension of each head
|
||||
rope_theta: Base value for RoPE (default 10000.0)
|
||||
rope_scale: Scaling factor for positions (default 1.0)
|
||||
Returns:
|
||||
Tuple of transformed query and key tensors
|
||||
"""
|
||||
device = q.device
|
||||
original_dtype = q.dtype
|
||||
|
||||
# Apply default values if None
|
||||
theta = 10000.0 if rope_theta is None else rope_theta
|
||||
scale = 1.0 if rope_scale is None else rope_scale
|
||||
|
||||
# Generate position indices
|
||||
position = torch.arange(seq_len, device=device).float()
|
||||
# Apply scaling factor to positions if provided
|
||||
if scale != 1.0:
|
||||
position = position / scale
|
||||
|
||||
# Create the frequency matrix - ensure stable computation in float32
|
||||
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
|
||||
# Compute the product of positions and frequencies
|
||||
# Shape: [seq_len, head_dim/2]
|
||||
freqs = torch.outer(position, inv_freq)
|
||||
|
||||
# Compute the rotation matrix elements: cos and sin
|
||||
# Shape: [seq_len, head_dim/2]
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
# Ensure stable computation of sin/cos in float32
|
||||
cos = torch.cos(emb).to(dtype=torch.float32)
|
||||
sin = torch.sin(emb).to(dtype=torch.float32)
|
||||
|
||||
# Reshape for broadcasting
|
||||
# Shape: [1, 1, seq_len, head_dim]
|
||||
cos = cos.view(1, 1, seq_len, head_dim)
|
||||
sin = sin.view(1, 1, seq_len, head_dim)
|
||||
|
||||
# Always compute in float32 for numerical stability
|
||||
q_float = q.to(dtype=torch.float32)
|
||||
k_float = k.to(dtype=torch.float32)
|
||||
|
||||
# For the even indices of the dimension
|
||||
q_embed_even = q_float[..., 0::2]
|
||||
q_embed_odd = q_float[..., 1::2]
|
||||
k_embed_even = k_float[..., 0::2]
|
||||
k_embed_odd = k_float[..., 1::2]
|
||||
|
||||
# Apply the rotation using the identities:
|
||||
# q' = q * cos + rotate(q) * sin
|
||||
# k' = k * cos + rotate(k) * sin
|
||||
# where rotate(x) swaps the even and odd dimensions and negates the odd dimensions
|
||||
q_rotated = torch.cat(
|
||||
[
|
||||
q_embed_even * cos[..., 0::2] - q_embed_odd * sin[..., 0::2],
|
||||
q_embed_odd * cos[..., 1::2] + q_embed_even * sin[..., 1::2],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
k_rotated = torch.cat(
|
||||
[
|
||||
k_embed_even * cos[..., 0::2] - k_embed_odd * sin[..., 0::2],
|
||||
k_embed_odd * cos[..., 1::2] + k_embed_even * sin[..., 1::2],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Convert back to the original dtype
|
||||
return q_rotated.to(dtype=original_dtype), k_rotated.to(dtype=original_dtype)
|
||||
|
||||
|
||||
def update_kv_cache(
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
@ -248,46 +162,6 @@ def update_kv_cache(
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
@torch.inference_mode()
|
||||
def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
position_ids (`torch.Tensor`):
|
||||
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
||||
used to pass offsetted position ids when working with a KV-cache.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
||||
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
||||
|
||||
q = q.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(q)
|
||||
k = k.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(k)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
@torch.library.custom_op("attention::fused_mla_ref", mutates_args=())
|
||||
def fused_mla_ref(
|
||||
q_nope: torch.Tensor,
|
||||
@ -325,23 +199,33 @@ def fused_mla_ref(
|
||||
value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous()
|
||||
|
||||
if freqs_cis is not None:
|
||||
cos = freqs_cis[0, ...]
|
||||
sin = freqs_cis[1, ...]
|
||||
for idx in range(seq_len.shape[0]):
|
||||
(
|
||||
q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
|
||||
k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
|
||||
) = apply_rotary_pos_emb_ds(
|
||||
q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
|
||||
k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
|
||||
cos_base = freqs_cis[0, ...]
|
||||
sin_base = freqs_cis[1, ...]
|
||||
for i in range(seq_len.shape[0]):
|
||||
start = seq_start[i]
|
||||
length = seq_len[i]
|
||||
if q_len == 1:
|
||||
idx = (input_pos[i] + length - 1).item()
|
||||
pos_ids = torch.tensor(idx, device=cos_base.device)
|
||||
else:
|
||||
pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device)
|
||||
|
||||
cos = cos_base[pos_ids] # [..., 1, head_dim]
|
||||
sin = sin_base[pos_ids]
|
||||
q_slice = q_pe[start : start + length]
|
||||
k_slice = k_pe[start : start + length]
|
||||
|
||||
q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
|
||||
q_slice,
|
||||
k_slice,
|
||||
cos,
|
||||
sin,
|
||||
torch.arange(input_pos[idx] + seq_len[idx])[-1]
|
||||
if q_len == 1
|
||||
else torch.arange(input_pos[idx] + seq_len[idx]),
|
||||
-2,
|
||||
)
|
||||
|
||||
q_pe[start : start + length] = q_rot
|
||||
k_pe[start : start + length] = k_rot
|
||||
|
||||
query_states = k_pe.new_empty(*bs_view, num_heads, q_head_dim) # [b*s,n,d]
|
||||
query_states[..., :qk_nope_head_dim] = q_nope
|
||||
query_states[..., qk_nope_head_dim:] = q_pe
|
||||
@ -454,7 +338,9 @@ def fused_mla(
|
||||
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
|
||||
kv_seq_len = value_states.shape[-2]
|
||||
|
||||
q_pe, k_pe = apply_rotary_pos_emb_ds(q_pe, k_pe, cos, sin, position_ids)
|
||||
cos = cos[position_ids]
|
||||
sin = sin[position_ids]
|
||||
q_pe, k_pe = torch.ops.rope.torch_apply_rope_with_qk_interleaving(q_pe, k_pe, cos, sin)
|
||||
|
||||
query_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim)
|
||||
query_states[:, :, :, :qk_nope_head_dim] = q_nope
|
||||
|
||||
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_rope.py (new file, 100 lines)
@@ -0,0 +1,100 @@
from typing import Tuple

import torch


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


@torch.library.custom_op("rope::torch_apply_rope_with_explicit_cos_sin", mutates_args=())
def torch_apply_rope_with_explicit_cos_sin(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reference PyTorch implementation of HF-style RoPE:
    - Input layout: non-interleaved, [B, N, S, D] with unsqueeze_dim=1 and
      [B, S, N, D] with unsqueeze_dim=2, default is [B, N, S, D]
    - Frequencies are provided as separate `cos` and `sin` tensors of shape [B, S, head_dim].
    """
    # in HF, cos/sin tensor are passed in as x.dtype, this is to double ensure
    cos = cos.type_as(q)
    sin = sin.type_as(q)
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


@torch_apply_rope_with_explicit_cos_sin.register_fake
def torch_apply_rope_with_explicit_cos_sin_fake(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(q), torch.empty_like(k)


@torch.library.custom_op("rope::torch_apply_rope_with_complex_freqs", mutates_args=())
def torch_apply_rope_with_complex_freqs(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,  # shape [B, S, head_dim//2]
    unsqueeze_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Reference PyTorch implementation of interleaved (complex) RoPE:
    - Input layout: [B, S, N, D] (interleaved)
    - Frequencies are combined into a single complex-valued tensor `freqs_cis`
      of shape [B, S, head_dim // 2].
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs = freqs_cis.unsqueeze(unsqueeze_dim)
    xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


@torch_apply_rope_with_complex_freqs.register_fake
def torch_apply_rope_with_complex_freqs_fake(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,  # shape [B, S, head_dim//2]
    unsqueeze_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(xq), torch.empty_like(xk)


@torch.library.custom_op("rope::torch_apply_rope_with_qk_interleaving", mutates_args=())
def torch_apply_rope_with_qk_interleaving(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    DeepSeek-style RoPE: interleaves Q/K channels and returns rotated (q_embed, k_embed).
    - Input layout: [B, S, N, D] or [B*S, N, D] or [B, N, S, D]
    - Frequencies are provided as separate `cos` and `sin` tensors of shape
      [B, S, 1, D] or [B*S, 1, D] or [B, 1, S, D] matching input shape.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    # Rewrite below code to accept 3D input:
    # b, h, s, d = q.shape
    # q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    # b, h, s, d = k.shape
    # k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    q = q.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(q)
    k = k.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(k)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


@torch_apply_rope_with_qk_interleaving.register_fake
def torch_apply_rope_with_qk_interleaving_fake(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(q), torch.empty_like(k)
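A quick, hypothetical smoke test for the three reference ops registered above (it is not part of the diff; it assumes this module has been imported so the rope:: ops are available, and shapes follow the docstrings):

import torch

B, N, S, D = 2, 4, 6, 64
pos = torch.arange(S).float()
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))
angles = torch.outer(pos, inv_freq)          # [S, D/2]
emb = torch.cat((angles, angles), dim=-1)    # [S, D]
cos = emb.cos()[None].expand(B, -1, -1)      # [B, S, D]
sin = emb.sin()[None].expand(B, -1, -1)

# HF-style explicit cos/sin, [B, N, S, D] layout with unsqueeze_dim=1.
q = torch.randn(B, N, S, D)
k = torch.randn(B, N, S, D)
q1, k1 = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin(q, k, cos, sin, 1)

# DeepSeek-style variant: same cos/sin, but Q/K channels are interleaved inside the op.
q2, k2 = torch.ops.rope.torch_apply_rope_with_qk_interleaving(q, k, cos, sin, 1)

# Complex (interleaved) variant, [B, S, N, D] layout with freqs_cis of shape [B, S, D/2].
xq = torch.randn(B, S, N, D)
xk = torch.randn(B, S, N, D)
freqs_cis = torch.polar(torch.ones(B, S, D // 2), angles[None].expand(B, -1, -1))
xq2, xk2 = torch.ops.rope.torch_apply_rope_with_complex_freqs(xq, xk, freqs_cis, 2)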
@ -1,12 +1,12 @@
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...utils.cuda_mem_tracker import cuda_memory_tracker
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import identify_regions_between_residuals, is_linear_op, is_op
|
||||
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
|
||||
@ -36,7 +36,7 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule:
|
||||
common_ancessor2 = _find_lowest_common_ancessor(arg2_list)
|
||||
if not common_ancessor2:
|
||||
continue
|
||||
selected_experts = _bfs(
|
||||
selected_experts = bfs(
|
||||
common_ancessor2,
|
||||
lambda node: is_op(node, torch.ops.aten.one_hot),
|
||||
attr_next="all_input_nodes",
|
||||
@ -153,26 +153,6 @@ def _insert_fused_moe_ops(gm: GraphModule):
|
||||
graph.erase_node(node)
|
||||
|
||||
|
||||
def _bfs(
|
||||
node: Node, target: Callable, attr_next: str = "users", boundary: Optional[Node] = None
|
||||
) -> Node:
|
||||
queue = [node]
|
||||
visited = set()
|
||||
while queue:
|
||||
cur_node = queue.pop(0)
|
||||
if boundary is not None and cur_node == boundary:
|
||||
continue # Skip the boundary node.
|
||||
if target(cur_node):
|
||||
return cur_node
|
||||
for next_node in getattr(cur_node, attr_next):
|
||||
if boundary is not None and next_node == boundary:
|
||||
continue # Do not expand past the boundary.
|
||||
if next_node not in visited:
|
||||
visited.add(next_node)
|
||||
queue.append(next_node)
|
||||
raise RuntimeError(f"Could not find node with target condition {target}.")
|
||||
|
||||
|
||||
def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
|
||||
"""
|
||||
Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following
|
||||
@ -326,7 +306,7 @@ def _find_final_hidden_state_node(
|
||||
For each expert output node (from the expert compute pattern), this function:
|
||||
1. Retrieves a multiplication node from its users.
|
||||
2. Extracts the second argument from the multiplication node (assumed to be the index node).
|
||||
3. Uses a BFS (via _bfs) to locate the subsequent index_add_ node (guarded by the end_boundary).
|
||||
3. Uses a BFS to locate the subsequent index_add_ node (guarded by the end_boundary).
|
||||
|
||||
After collecting all such index_add_ nodes, the final hidden state node is determined
|
||||
as the one that is not used by any of the other index_add_ nodes.
|
||||
@ -346,7 +326,7 @@ def _find_final_hidden_state_node(
|
||||
if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
|
||||
return None
|
||||
index_node = mul_node.args[1]
|
||||
index_add_node = _bfs(
|
||||
index_add_node = bfs(
|
||||
index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary
|
||||
)
|
||||
if not index_add_node:
|
||||
@ -412,7 +392,7 @@ def _remove_dead_inplace_nodes_in_region(
|
||||
return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0
|
||||
|
||||
try:
|
||||
node_to_remove = _bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
|
||||
node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
|
||||
ad_logger.debug(f"Removing In-place Dead Node: {node_to_remove}")
|
||||
graph.erase_node(node_to_remove)
|
||||
return True
|
||||
|
||||
@ -1,86 +1,141 @@
|
||||
"""
|
||||
This transformation defines two main RoPE (Rotary Positional Embedding) pattern matchers used
|
||||
to identify and replace RoPE subgraphs with a custom op (`torch.ops.rope.flashinfer`).
|
||||
|
||||
Supported RoPE variants:
|
||||
|
||||
1. Explicit Cos/Sin Multiplication (HF-style, e.g., LLaMA, Mixtral, Qwen)
|
||||
- Input layout: non-interleaved, [B, N, S, D] with unsqueeze_dim=1 and
|
||||
[B, S, N, D] with unsqueeze_dim=2, default is [B, N, S, D]
|
||||
- Frequencies are provided as separate `cos` and `sin` tensors of shape [B, S, head_dim].
|
||||
- Source code:
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
2. Complex Multiplication (GPTJ/Llama-stack-style, interleaved)
|
||||
- Input layout: [B, S, N, D] (interleaved)
|
||||
- Frequencies are combined into a single complex-valued tensor `freqs_cis` of shape [B, S, head_dim // 2].
|
||||
- Source code:
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
Supported Minor variants:
|
||||
- DeepSeekV3: reshape + transpose before applying RoPE.
|
||||
dynamic position-based updates to frequency cache.
|
||||
|
||||
TODO: Support other variants:
|
||||
- Phi-4: rotary applied only to part of the hidden dimension (q_rot, q_pass split).
|
||||
- LLaMA4 Vision: 2D rotary frequencies constructed from image patches.
|
||||
"""
|
||||
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from typing import Any, DefaultDict, Dict, List, Optional
|
||||
from typing import Any, DefaultDict, Dict, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_op
|
||||
from ...utils.node_utils import bfs, extract_output_tuple, identify_regions_between_residuals, is_op
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
|
||||
def match_rope_v1(gm: GraphModule) -> GraphModule:
|
||||
def match_explicit_rope(gm: GraphModule) -> GraphModule:
|
||||
"""
|
||||
Identify and replace legacy RoPE subgraphs (explicit cos/sin multiplication pattern):
|
||||
Identify and replace RoPE subgraphs (explicit cos/sin multiplication pattern):
|
||||
|
||||
output = (raw * unsqueeze(cos)) + (rotate_half(raw) * unsqueeze(sin))
|
||||
- HF-style: output = (raw * unsqueeze(cos)) + (rotate_half(raw) * unsqueeze(sin))
|
||||
- DS-style: requires interleaving Q/K before the cos/sin mul
|
||||
|
||||
If exactly two such branches (query and key) are detected within each region, they're replaced
|
||||
by a call to `torch.ops.rope.flashinfer`.
|
||||
by a call to `rope::torch_apply_rope_with_qk_interleaving` or
|
||||
`rope::torch_apply_rope_with_explicit_cos_sin` respectively.
|
||||
"""
|
||||
ad_logger.info("Match explicit(HF) style RoPE")
|
||||
graph = gm.graph
|
||||
boundary_nodes: List[torch.fx.Node] = identify_regions_between_residuals(gm)
|
||||
|
||||
rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None)
|
||||
rope_position_ids_cache: Dict[str, Node] = {}
|
||||
|
||||
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
|
||||
matches = []
|
||||
matches = [] # list of (match_info, is_ds)
|
||||
node = start_boundary
|
||||
while node != end_boundary:
|
||||
if is_op(node, torch.ops.aten.add):
|
||||
match_info = _match_rotary_subpattern_V1(node)
|
||||
if match_info:
|
||||
matches.append(match_info)
|
||||
explicit = _match_explicit_rope_subpattern(node)
|
||||
if explicit is not None:
|
||||
raw = explicit["raw_input"]
|
||||
# check if this raw is result of DS interleave
|
||||
inter = _match_input_interleave_pattern(raw)
|
||||
if inter is not None:
|
||||
ds_match = {
|
||||
"raw_input": inter["interleaved"],
|
||||
"unsqueeze_cos": explicit["unsqueeze_cos"],
|
||||
"unsqueeze_sin": explicit["unsqueeze_sin"],
|
||||
"add_node": explicit["add_node"],
|
||||
}
|
||||
matches.append((ds_match, True))
|
||||
else:
|
||||
matches.append((explicit, False))
|
||||
node = node.next
|
||||
|
||||
if not matches:
|
||||
continue
|
||||
if len(matches) != 2:
|
||||
raise RuntimeError(
|
||||
ad_logger.warning(
|
||||
f"Expected exactly 2 legacy RoPE branches between {start_boundary} and {end_boundary}, "
|
||||
f"found {len(matches)}."
|
||||
)
|
||||
continue
|
||||
|
||||
# Assume the first matched branch is query (q), second is key (k).
|
||||
# This assumption is based on the default ordering in the exported graph,
|
||||
# since node naming conventions don't reliably indicate q/k branches.
|
||||
q_match, k_match = matches
|
||||
_process_rope_v1(
|
||||
graph,
|
||||
q_match,
|
||||
k_match,
|
||||
start_boundary,
|
||||
rope_flash_cache,
|
||||
rope_position_ids_cache,
|
||||
)
|
||||
(q_match, q_is_ds), (k_match, k_is_ds) = matches
|
||||
if q_is_ds != k_is_ds:
|
||||
ad_logger.warning("Mismatched RoPE types between q and k branches")
|
||||
continue
|
||||
|
||||
if q_is_ds:
|
||||
_process_input_interleave_rope(graph, q_match, k_match)
|
||||
else:
|
||||
_process_explicit_rope(graph, q_match, k_match, start_boundary)
|
||||
|
||||
gm = canonicalize_graph(gm)
|
||||
return gm
|
||||
|
||||
|
||||
def match_rope_v2(gm: GraphModule) -> GraphModule:
|
||||
def match_complex_rope(gm: GraphModule) -> GraphModule:
|
||||
"""
|
||||
Identify and replace RoPE subgraphs using complex multiplication pattern:
|
||||
|
||||
output = type_as(flatten(view_as_real(mul(view_as_complex(reshape(to_dtype(x))), unsqueeze(freqs_cis, 2)))), x)
|
||||
|
||||
If exactly two such branches (query and key) are detected within each region, they're replaced
|
||||
by a call to `torch.ops.rope.flashinfer`.
|
||||
by a call to `torch.ops.rope.torch_apply_rope_with_complex_freqs`.
|
||||
"""
|
||||
ad_logger.info("Match Complex style RoPE")
|
||||
graph = gm.graph
|
||||
boundary_nodes: List[torch.fx.Node] = identify_regions_between_residuals(gm)
|
||||
|
||||
rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None)
|
||||
rope_position_ids_cache: Dict[str, Node] = {}
|
||||
|
||||
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
|
||||
matches = []
|
||||
node = start_boundary
|
||||
while node != end_boundary:
|
||||
if is_op(node, torch.ops.aten.type_as):
|
||||
match_info = _match_rotary_subpattern_V2(node)
|
||||
match_info = _match_complex_rope_subpattern(node)
|
||||
if match_info:
|
||||
matches.append(match_info)
|
||||
node = node.next
|
||||
@ -88,28 +143,344 @@ def match_rope_v2(gm: GraphModule) -> GraphModule:
|
||||
if not matches:
|
||||
continue
|
||||
if len(matches) != 2:
|
||||
raise RuntimeError(
|
||||
ad_logger.warning(
|
||||
f"Expected exactly 2 complex RoPE branches between {start_boundary} and {end_boundary}, "
|
||||
f"found {len(matches)}."
|
||||
)
|
||||
continue
|
||||
|
||||
# Assume the first matched branch is query (q), second is key (k).
|
||||
# This assumption is based on the default ordering in the exported graph,
|
||||
# since node naming conventions don't reliably indicate q/k branches.
|
||||
q_match, k_match = matches
|
||||
_process_rope_v2(
|
||||
graph,
|
||||
q_match,
|
||||
k_match,
|
||||
rope_flash_cache,
|
||||
rope_position_ids_cache,
|
||||
)
|
||||
_process_complex_rope(graph, q_match, k_match)
|
||||
|
||||
gm = canonicalize_graph(gm)
|
||||
return gm
|
||||
|
||||
|
||||
def _match_rotary_subpattern_V1(add_node: Node) -> Optional[Dict[str, Node]]:
|
||||
def _get_default_unsqueeze_dim(op):
|
||||
schema = next(iter(op._schemas.values()))
|
||||
for a in schema.arguments:
|
||||
if a.name == "unsqueeze_dim" and a.has_default_value:
|
||||
return a.default_value
|
||||
raise RuntimeError(f"No default unsqueeze_dim on {op}")
|
||||
|
||||
|
||||
def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphModule:
|
||||
"""
|
||||
Match and transform input and output of rope ops to the layout specified to meet requirements of optimized ops.
|
||||
Supported layout is 'bsnd' (batch, seq, head, dim).
|
||||
"""
|
||||
supported = {"bsnd", "bnsd"}
|
||||
if expected_layout.lower() not in supported:
|
||||
ad_logger.warning(
|
||||
f"Unsupported RoPE layout '{expected_layout}'; expected '{supported}'. Skipping RoPE layout matching."
|
||||
)
|
||||
return gm
|
||||
|
||||
ad_logger.info(f"Match RoPE layout to {expected_layout}")
|
||||
|
||||
graph = gm.graph
|
||||
rope_ops = {
|
||||
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin,
|
||||
torch.ops.rope.torch_apply_rope_with_qk_interleaving,
|
||||
torch.ops.rope.torch_apply_rope_with_complex_freqs,
|
||||
}
|
||||
|
||||
need_transpose = False
|
||||
need_canonicalize_graph = False
|
||||
for node in graph.nodes:
|
||||
if not is_op(node, rope_ops):
|
||||
continue
|
||||
|
||||
rope_op = next(op for op in rope_ops if is_op(node, op))
|
||||
if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
|
||||
q_node, k_node, freqs_node, *rest = node.args
|
||||
unsq = rest[0] if rest else _get_default_unsqueeze_dim(rope_op)
|
||||
else:
|
||||
q_node, k_node, cos_node, sin_node, *rest = node.args
|
||||
unsq = rest[0] if rest else _get_default_unsqueeze_dim(rope_op)
|
||||
|
||||
if unsq == 2:
|
||||
current_layout = "bsnd"
|
||||
elif unsq == 1:
|
||||
current_layout = "bnsd"
|
||||
else:
|
||||
ad_logger.warning(
|
||||
"Unsqueeze_dim is not one of [1, 2]. "
|
||||
"Unable to infer layout of q node. Skip layout matching"
|
||||
)
|
||||
continue
|
||||
|
||||
need_transpose = expected_layout.lower() != current_layout
|
||||
|
||||
if not need_transpose:
|
||||
continue
|
||||
|
||||
need_canonicalize_graph = True
|
||||
# retrieve q and k output node from node
|
||||
q_rope_old, k_rope_old = extract_output_tuple(node, 2)
|
||||
if q_rope_old is None or k_rope_old is None:
|
||||
ad_logger.warning(
|
||||
f"Failed to extract all two outputs from the explicit op, \
|
||||
get {q_rope_old}, {k_rope_old}, fail to match rope layout with {node} with"
|
||||
)
|
||||
continue
|
||||
|
||||
ad_logger.debug(
|
||||
f"Inferred RoPE input layout: '{current_layout}']Mapping layout to '{expected_layout}']"
|
||||
)
|
||||
with graph.inserting_before(node):
|
||||
q_for_op = graph.call_function(torch.ops.aten.transpose, args=(q_node, 1, 2))
|
||||
k_for_op = graph.call_function(torch.ops.aten.transpose, args=(k_node, 1, 2))
|
||||
q_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(q_for_op,))
|
||||
k_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(k_for_op,))
|
||||
|
||||
q_for_op_contig.meta["val"] = q_node.meta["val"].transpose(1, 2)
|
||||
k_for_op_contig.meta["val"] = k_node.meta["val"].transpose(1, 2)
|
||||
|
||||
if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
|
||||
new_args = (
|
||||
q_for_op_contig,
|
||||
k_for_op_contig,
|
||||
freqs_node,
|
||||
2 if expected_layout.lower() == "bsnd" else 1,
|
||||
) # unsqueeze_dim updated
|
||||
else:
|
||||
new_args = (
|
||||
q_for_op_contig,
|
||||
k_for_op_contig,
|
||||
cos_node,
|
||||
sin_node,
|
||||
2 if expected_layout.lower() == "bsnd" else 1,
|
||||
) # unsqueeze_dim updated
|
||||
node.args = new_args
|
||||
|
||||
with graph.inserting_after(q_rope_old):
|
||||
q_rope_new = graph.call_function(torch.ops.aten.transpose, args=(q_rope_old, 1, 2))
|
||||
with graph.inserting_after(k_rope_old):
|
||||
k_rope_new = graph.call_function(torch.ops.aten.transpose, args=(k_rope_old, 1, 2))
|
||||
|
||||
# Preserve fake tensor in meta["val"] for the transposed inputs
|
||||
q_rope_new.meta["val"] = q_rope_old.meta["val"]
|
||||
q_rope_old.meta["val"] = q_rope_old.meta["val"].transpose(1, 2)
|
||||
k_rope_new.meta["val"] = k_rope_old.meta["val"]
|
||||
k_rope_old.meta["val"] = k_rope_old.meta["val"].transpose(1, 2)
|
||||
|
||||
q_rope_old.replace_all_uses_with(q_rope_new)
|
||||
k_rope_old.replace_all_uses_with(k_rope_new)
|
||||
q_rope_new.args = (q_rope_old, 1, 2)
|
||||
k_rope_new.args = (k_rope_old, 1, 2)
|
||||
|
||||
if need_canonicalize_graph:
|
||||
gm = canonicalize_graph(gm)
|
||||
return gm
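The layout inference above hinges on a simple convention, illustrated below with hypothetical shapes (this sketch is not part of the diff): unsqueeze_dim=1 corresponds to a [B, N, S, D] ('bnsd') input, unsqueeze_dim=2 to [B, S, N, D] ('bsnd'), and transposing dims 1 and 2 converts between the two.

import torch

B, N, S, D = 2, 4, 6, 64
cos = torch.randn(B, S, D)

q_bnsd = torch.randn(B, N, S, D)
q_bsnd = q_bnsd.transpose(1, 2).contiguous()

# Each unsqueeze_dim lines the cos/sin broadcast up with its layout.
assert (q_bnsd * cos.unsqueeze(1)).shape == (B, N, S, D)
assert (q_bsnd * cos.unsqueeze(2)).shape == (B, S, N, D)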
|
||||
|
||||
|
||||
def optimize_rope(gm: GraphModule) -> GraphModule:
|
||||
"""
|
||||
Scan the FX graph and replace calls to the torch-reference RoPE ops with
|
||||
the optimized `rope::flashinfer` kernel.
|
||||
Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
|
||||
and reuses those nodes when possible.
|
||||
"""
|
||||
ad_logger.info("RoPE optimization")
|
||||
graph = gm.graph
|
||||
rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None)
|
||||
rope_position_ids_cache: Dict[str, Node] = {}
|
||||
|
||||
for node in list(graph.nodes):
|
||||
if is_op(node, torch.ops.rope.torch_apply_rope_with_explicit_cos_sin):
|
||||
_optimize_explicit(graph, node, rope_flash_cache, rope_position_ids_cache)
|
||||
elif is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
|
||||
_optimize_complex(graph, node, rope_flash_cache, rope_position_ids_cache)
|
||||
|
||||
gm = canonicalize_graph(gm)
|
||||
return gm
|
||||
|
||||
|
||||
def _optimize_explicit(
|
||||
graph: GraphModule, node: Node, cache: Dict[Any, Node], pos_cache: Dict[str, Node]
|
||||
) -> None:
|
||||
# node.args may be (q, k, cos, sin) or (q, k, cos, sin, unsq)
|
||||
q_node, k_node, cos_node, sin_node, *rest = node.args
|
||||
# retrieve q and k output node from node
|
||||
q_rope_old, k_rope_old = extract_output_tuple(node, 2)
|
||||
if q_rope_old is None or k_rope_old is None:
|
||||
ad_logger.warning(
|
||||
f"Failed to extract all two outputs from the explicit op, \
|
||||
get {q_rope_old}, {k_rope_old}, fail to replace {node} with flashinfer rope"
|
||||
)
|
||||
return
|
||||
|
||||
# Sanity check on head_dim
|
||||
if not _validate_rope_inputs(q_node, k_node):
|
||||
return
|
||||
|
||||
# Sanity check that input layout is BSND (no transpose needed).
|
||||
q_fake = q_node.meta.get("val", None)
|
||||
if q_fake is not None and len(q_fake.shape) > 2:
|
||||
if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)):
|
||||
ad_logger.warning(
|
||||
f"""Sanity check failed: q_fake should have shape [b, s, n, d],
|
||||
s should be symbolic and n should be int, instead got shape {q_fake.shape}"""
|
||||
)
|
||||
return
|
||||
elif q_fake is not None:
|
||||
ad_logger.warning(
|
||||
f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}"
|
||||
)
|
||||
return
|
||||
|
||||
head_dim = cos_node.meta["val"].shape[-1]
|
||||
half_head_dim = head_dim // 2
|
||||
|
||||
cache_key = (cos_node, sin_node)
|
||||
if cache_key in cache:
|
||||
fused_cos_sin_to = cache[cache_key]
|
||||
else:
|
||||
with graph.inserting_after(cos_node):
|
||||
cos_prefix = graph.call_function(
|
||||
torch.ops.aten.slice, args=(cos_node, -1, 0, half_head_dim)
|
||||
)
|
||||
with graph.inserting_after(sin_node):
|
||||
sin_prefix = graph.call_function(
|
||||
torch.ops.aten.slice, args=(sin_node, -1, 0, half_head_dim)
|
||||
)
|
||||
with graph.inserting_after(sin_prefix):
|
||||
fused_cos_sin = graph.call_function(
|
||||
torch.ops.aten.cat, args=((cos_prefix, sin_prefix), -1)
|
||||
)
|
||||
with graph.inserting_after(q_node):
|
||||
sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 0))
|
||||
sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 1))
|
||||
with graph.inserting_after(_get_last_node([sym_batch, sym_seq])):
|
||||
bs_seq = graph.call_function(operator.mul, args=(sym_batch, sym_seq))
|
||||
with graph.inserting_after(_get_last_node([bs_seq, fused_cos_sin])):
|
||||
fused_cos_sin_flat = graph.call_function(
|
||||
torch.ops.aten.view, args=(fused_cos_sin, (bs_seq, -1))
|
||||
)
|
||||
with graph.inserting_after(fused_cos_sin_flat):
|
||||
fused_cos_sin_to = graph.call_function(
|
||||
torch.ops.aten.to, args=(fused_cos_sin_flat, torch.float32)
|
||||
)
|
||||
cache[cache_key] = fused_cos_sin_to
|
||||
|
||||
with graph.inserting_before(node):
|
||||
position_ids = _get_position_ids(
|
||||
graph,
|
||||
q_node,
|
||||
batch_dim=0,
|
||||
seq_dim=1,
|
||||
rope_position_ids_cache=pos_cache,
|
||||
)
|
||||
flash_node = graph.call_function(
|
||||
torch.ops.rope.flashinfer,
|
||||
args=(q_node, k_node, position_ids, fused_cos_sin_to, True),
|
||||
)
|
||||
|
||||
with graph.inserting_after(flash_node):
|
||||
q_rope_new = graph.call_function(operator.getitem, args=(flash_node, 0))
|
||||
k_rope_new = graph.call_function(operator.getitem, args=(flash_node, 1))
|
||||
|
||||
q_rope_new.meta["val"] = q_rope_old.meta.get("val", None)
|
||||
k_rope_new.meta["val"] = k_rope_old.meta.get("val", None)
|
||||
|
||||
q_rope_old.replace_all_uses_with(q_rope_new)
|
||||
k_rope_old.replace_all_uses_with(k_rope_new)
|
||||
|
||||
graph.erase_node(q_rope_old)
|
||||
graph.erase_node(k_rope_old)
|
||||
|
||||
|
||||
def _optimize_complex(
|
||||
graph: GraphModule, node: Node, cache: Dict[Any, Node], pos_cache: Dict[str, Node]
|
||||
) -> None:
|
||||
q_node, k_node, inv_freq_node = node.args
|
||||
|
||||
# Sanity check on head_dim
|
||||
if not _validate_rope_inputs(q_node, k_node):
|
||||
return
|
||||
|
||||
# Sanity check that input layout is BSND (no transpose needed).
|
||||
q_fake = q_node.meta.get("val", None)
|
||||
if q_fake is not None and len(q_fake.shape) > 2:
|
||||
if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)):
|
||||
ad_logger.warning(
|
||||
f"""Sanity check failed: q_fake should have shape [b, s, n, d],
|
||||
s should be symbolic and n should be int, instead got shape {q_fake.shape}"""
|
||||
)
|
||||
return
|
||||
elif q_fake is not None:
|
||||
ad_logger.warning(
|
||||
f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}"
|
||||
)
|
||||
return
|
||||
|
||||
# Retrieve or register the lookup table for inv_freq_node -> cos_sin_flash
|
||||
if inv_freq_node in cache:
|
||||
cos_sin_flash = cache[inv_freq_node]
|
||||
else:
|
||||
# Compute the fused cosine/sine cache.
|
||||
with graph.inserting_after(inv_freq_node):
|
||||
real_part = graph.call_function(torch.ops.aten.real, args=(inv_freq_node,))
|
||||
imag_part = graph.call_function(torch.ops.aten.imag, args=(inv_freq_node,))
|
||||
with graph.inserting_after(real_part):
|
||||
cos_sin_flash_3d = graph.call_function(
|
||||
torch.ops.aten.cat, args=((real_part, imag_part), -1)
|
||||
)
|
||||
with graph.inserting_after(q_node):
|
||||
sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 0))
|
||||
sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 1))
|
||||
with graph.inserting_after(_get_last_node([sym_batch, sym_seq])):
|
||||
bs_seq = graph.call_function(operator.mul, args=(sym_batch, sym_seq))
|
||||
with graph.inserting_after(_get_last_node([bs_seq, cos_sin_flash_3d])):
|
||||
fused_cos_sin_flat = graph.call_function(
|
||||
torch.ops.aten.view, args=(cos_sin_flash_3d, (bs_seq, -1))
|
||||
)
|
||||
with graph.inserting_after(fused_cos_sin_flat):
|
||||
cos_sin_flash = graph.call_function(
|
||||
torch.ops.aten.to, args=(fused_cos_sin_flat, torch.float32)
|
||||
)
|
||||
cache[inv_freq_node] = cos_sin_flash
|
||||
|
||||
with graph.inserting_before(node):
|
||||
position_ids = _get_position_ids(
|
||||
graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=pos_cache
|
||||
)
|
||||
flash_node = graph.call_function(
|
||||
torch.ops.rope.flashinfer,
|
||||
args=(q_node, k_node, position_ids, cos_sin_flash, False),
|
||||
)
|
||||
|
||||
flash_node.meta["val"] = node.meta.get("val", None)
|
||||
node.replace_all_uses_with(flash_node)
|
||||
graph.erase_node(node)
|
||||
|
||||
|
||||
def _match_input_interleave_pattern(node: Node) -> Optional[Dict[str, Node]]:
|
||||
"""
|
||||
Detect DeepSeek-style interleave on Q/K:
|
||||
reshape(transpose(view(raw, [b,h,s,d//2,2]), 4, 3), [b,h,s,d])
|
||||
Returns:
|
||||
{"interleaved": raw_node} if matched, else None.
|
||||
"""
|
||||
if not is_op(node, torch.ops.aten.reshape):
|
||||
return None
|
||||
transpose_node = node.args[0]
|
||||
if not is_op(transpose_node, torch.ops.aten.transpose):
|
||||
return None
|
||||
view_node = transpose_node.args[0]
|
||||
if not is_op(view_node, torch.ops.aten.view):
|
||||
return None
|
||||
raw_node = view_node.args[0]
|
||||
if not isinstance(raw_node, Node):
|
||||
return None
|
||||
return {"interleaved": raw_node}
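A small sanity check (not part of the diff) that the "view -> transpose -> reshape" interleave matched here is the same channel shuffle the reference op torch_apply_rope_with_qk_interleaving applies via unflatten/transpose/reshape_as:

import torch

b, h, s, d = 2, 3, 4, 8
x = torch.arange(b * h * s * d, dtype=torch.float32).reshape(b, h, s, d)

ds_style = x.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
op_style = x.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(x)

assert torch.equal(ds_style, op_style)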
|
||||
|
||||
|
||||
def _match_explicit_rope_subpattern(add_node: Node) -> Optional[Dict[str, Node]]:
|
||||
"""
|
||||
Given an aten.add.Tensor node that is expected to compute:
|
||||
output = (raw_input * unsqueeze(cos)) + (rotate_half(raw_input) * unsqueeze(sin))
|
||||
@ -195,7 +566,7 @@ def _match_rotary_subpattern_V1(add_node: Node) -> Optional[Dict[str, Node]]:
|
||||
}
|
||||
|
||||
|
||||
def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]]:
|
||||
def _match_complex_rope_subpattern(type_as_node: Node) -> Optional[Dict[str, Node]]:
|
||||
"""
|
||||
Given a type_as node, this function inspects the graph
|
||||
structure and returns a dictionary with:
|
||||
@ -245,9 +616,10 @@ def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]]
|
||||
else:
|
||||
return None
|
||||
|
||||
# Verify that the unsqueeze is performed along dimension 2.
|
||||
if not (len(unsqueeze_node.args) >= 2 and unsqueeze_node.args[1] == 2):
|
||||
if not (len(unsqueeze_node.args) >= 2):
|
||||
return None
|
||||
unsqueeze_dim = unsqueeze_node.args[1]
|
||||
|
||||
inv_freq_candidate = unsqueeze_node.args[0]
|
||||
|
||||
# Match the view_as_complex branch.
|
||||
@ -273,27 +645,27 @@ def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]]
|
||||
"input": input_tensor,
|
||||
"inv_freq": inv_freq_candidate,
|
||||
"out": type_as_node,
|
||||
"unsqueeze_dim": unsqueeze_dim,
|
||||
}
|
||||
|
||||
|
||||
def _process_rope_v1(
|
||||
graph: GraphModule,
|
||||
def _process_explicit_rope(
|
||||
graph: GraphModule.graph,
|
||||
q_match: Dict[str, Node],
|
||||
k_match: Dict[str, Node],
|
||||
start_boundary: Node,
|
||||
rope_flash_cache: DefaultDict[Any, Optional[Node]],
|
||||
rope_position_ids_cache: Dict[str, Node],
|
||||
) -> None:
|
||||
"""
|
||||
Process a region that matched the legacy RoPE pattern (v1).
|
||||
Inserts the custom op (flashinfer) and replaces the original add nodes.
|
||||
Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
|
||||
and reuses those nodes when possible.
|
||||
Replace matched Explicit RoPE subgraph with `rope::torch_apply_rope_with_explicit_cos_sin`.
|
||||
"""
|
||||
q_node = q_match["raw_input"]
|
||||
k_node = k_match["raw_input"]
|
||||
cos_node = q_match["unsqueeze_cos"].args[0]
|
||||
sin_node = q_match["unsqueeze_sin"].args[0]
|
||||
cos_unsq = q_match["unsqueeze_cos"]
|
||||
sin_unsq = q_match["unsqueeze_sin"]
|
||||
cos_node = cos_unsq.args[0]
|
||||
sin_node = sin_unsq.args[0]
|
||||
unsq_dim = cos_unsq.args[1]
|
||||
add_node = q_match["add_node"]
|
||||
|
||||
# Sanity-check: ensure cos/sin nodes trace back to aten.cos/aten.sin.
|
||||
bfs(
|
||||
@ -309,167 +681,197 @@ def _process_rope_v1(
|
||||
boundary=start_boundary,
|
||||
)
|
||||
|
||||
# Infer input layout; default to [b, n, s, d] if inference fails.
|
||||
q_fake = q_node.meta.get("val", None)
|
||||
if q_fake is not None and len(q_fake.shape) > 2:
|
||||
need_transpose = isinstance(q_fake.shape[1], int)
|
||||
ad_logger.debug(
|
||||
f"Inferred RoPE input layout: [{'[b, n, s, d]' if need_transpose else '[b, s, n, d]'}]"
|
||||
)
|
||||
# Additional sanity check for the third dimension
|
||||
if need_transpose:
|
||||
if not isinstance(q_fake.shape[2], torch.SymInt):
|
||||
ad_logger.warning(
|
||||
"Sanity check failed: q_fake.shape[2] should be symbolic. Defaulting to [b, n, s, d]"
|
||||
)
|
||||
need_transpose = True
|
||||
else:
|
||||
if not isinstance(q_fake.shape[1], torch.SymInt):
|
||||
ad_logger.warning(
|
||||
"Sanity check failed: q_fake.shape[2] should be symbolic. Defaulting to [b, n, s, d]"
|
||||
)
|
||||
need_transpose = True
|
||||
else:
|
||||
ad_logger.warning("Unable to infer layout of q node. Defaulting to [b, n, s, d].")
|
||||
need_transpose = True
|
||||
|
||||
with graph.inserting_before(q_match["add_node"]):
|
||||
if need_transpose:
|
||||
q_for_op = graph.call_function(torch.ops.aten.transpose.int, args=(q_node, 1, 2))
|
||||
k_for_op = graph.call_function(torch.ops.aten.transpose.int, args=(k_node, 1, 2))
|
||||
q_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(q_for_op,))
|
||||
k_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(k_for_op,))
|
||||
else:
|
||||
q_for_op_contig, k_for_op_contig = q_node, k_node
|
||||
|
||||
head_dim = cos_node.meta["val"].shape[-1]
|
||||
half_head_dim = head_dim // 2
|
||||
|
||||
cache = rope_flash_cache
|
||||
cache_key = (cos_node, sin_node)
|
||||
if cache_key in cache:
|
||||
fused_cos_sin = cache[cache_key]
|
||||
else:
|
||||
cos_prefix = graph.call_function(
|
||||
torch.ops.aten.slice, args=(cos_node, -1, 0, half_head_dim)
|
||||
)
|
||||
sin_prefix = graph.call_function(
|
||||
torch.ops.aten.slice, args=(sin_node, -1, 0, half_head_dim)
|
||||
)
|
||||
fused_cos_sin = graph.call_function(
|
||||
torch.ops.aten.cat, args=((cos_prefix, sin_prefix), -1)
|
||||
)
|
||||
fused_cos_sin = graph.call_function(operator.getitem, args=(fused_cos_sin, 0))
|
||||
fused_cos_sin = graph.call_function(
|
||||
torch.ops.aten.to.dtype, args=(fused_cos_sin, torch.float32)
|
||||
)
|
||||
cache[cache_key] = fused_cos_sin
|
||||
|
||||
position_ids = _get_position_ids(
|
||||
graph,
|
||||
q_for_op_contig,
|
||||
batch_dim=0,
|
||||
seq_dim=1,
|
||||
rope_position_ids_cache=rope_position_ids_cache,
|
||||
with graph.inserting_before(add_node):
|
||||
rope_node = graph.call_function(
|
||||
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin,
|
||||
args=(q_node, k_node, cos_node, sin_node, unsq_dim),
|
||||
)
|
||||
|
||||
flash_node = graph.call_function(
|
||||
torch.ops.rope.flashinfer,
|
||||
args=(q_for_op_contig, k_for_op_contig, position_ids, fused_cos_sin, True),
|
||||
with graph.inserting_after(rope_node):
|
||||
out_q = graph.call_function(operator.getitem, args=(rope_node, 0))
|
||||
out_k = graph.call_function(operator.getitem, args=(rope_node, 1))
|
||||
|
||||
out_q.meta["val"] = add_node.meta.get("val", None)
|
||||
out_k.meta["val"] = k_match["add_node"].meta.get("val", None)
|
||||
|
||||
q_match["add_node"].replace_all_uses_with(out_q)
|
||||
k_match["add_node"].replace_all_uses_with(out_k)
|
||||
|
||||
|
||||
def _process_complex_rope(
|
||||
graph: GraphModule.graph,
|
||||
q_match: Dict[str, Node],
|
||||
k_match: Dict[str, Node],
|
||||
) -> None:
|
||||
"""
|
||||
Replace matched Complex RoPE subgraph with `rope::torch_apply_rope_with_complex_freqs`.
|
||||
"""
|
||||
xq = q_match["input"]
|
||||
xk = k_match["input"]
|
||||
inv = q_match["inv_freq"]
|
||||
usdim = q_match["unsqueeze_dim"]
|
||||
out_node = q_match.get("out")
|
||||
|
||||
if inv != k_match["inv_freq"]:
|
||||
ad_logger.warning(
|
||||
"Mismatch of freqs_cis (inv_freq) between branches. Fail to match complex rope pattern"
|
||||
)
|
||||
return
|
||||
|
||||
with graph.inserting_before(out_node):
|
||||
rope_node = graph.call_function(
|
||||
torch.ops.rope.torch_apply_rope_with_complex_freqs,
|
||||
args=(xq, xk, inv, usdim),
|
||||
)
|
||||
|
||||
with graph.inserting_after(flash_node):
|
||||
raw_q = graph.call_function(operator.getitem, args=(flash_node, 0))
|
||||
raw_k = graph.call_function(operator.getitem, args=(flash_node, 1))
|
||||
with graph.inserting_after(rope_node):
|
||||
out_q = graph.call_function(operator.getitem, args=(rope_node, 0))
|
||||
out_k = graph.call_function(operator.getitem, args=(rope_node, 1))
|
||||
|
||||
if need_transpose:
|
||||
with graph.inserting_after(raw_q):
|
||||
new_q = graph.call_function(torch.ops.aten.transpose.int, args=(raw_q, 1, 2))
|
||||
with graph.inserting_after(raw_k):
|
||||
new_k = graph.call_function(torch.ops.aten.transpose.int, args=(raw_k, 1, 2))
|
||||
else:
|
||||
new_q, new_k = raw_q, raw_k
|
||||
out_q.meta["val"] = out_node.meta.get("val", None)
|
||||
out_k.meta["val"] = k_match["out"].meta.get("val", None)
|
||||
|
||||
new_q.meta["val"] = q_match["add_node"].meta.get("val", None)
|
||||
new_k.meta["val"] = k_match["add_node"].meta.get("val", None)
|
||||
|
||||
q_match["add_node"].replace_all_uses_with(new_q)
|
||||
k_match["add_node"].replace_all_uses_with(new_k)
|
||||
out_node.replace_all_uses_with(out_q)
|
||||
k_match["out"].replace_all_uses_with(out_k)
|
||||
|
||||
|
||||
def _process_rope_v2(
|
||||
def _process_input_interleave_rope(
|
||||
graph: GraphModule,
|
||||
q_match: Dict[str, Node],
|
||||
k_match: Dict[str, Node],
|
||||
rope_flash_cache: DefaultDict[Any, Optional[Node]],
|
||||
rope_position_ids_cache: Dict[str, Node],
|
||||
) -> None:
|
||||
"""
|
||||
Process a region that matched the complex-multiplication RoPE pattern (v2).
|
||||
Inserts the custom op (flashinfer) after extracting frequency information,
|
||||
and replaces the original type_as nodes.
|
||||
Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
|
||||
and reuses those nodes when possible.
|
||||
Replace a matched DS-style RoPE subgraph with a call to rope::torch_apply_rope_with_qk_interleaving.
|
||||
Cache the one-time unsqueeze of cos/sin.
|
||||
"""
|
||||
q_node = q_match["input"]
|
||||
k_node = k_match["input"]
|
||||
inv_freq_node = q_match["inv_freq"]
|
||||
q_node = q_match["raw_input"]
|
||||
k_node = k_match["raw_input"]
|
||||
cos_node = q_match["unsqueeze_cos"].args[0]
|
||||
sin_node = q_match["unsqueeze_sin"].args[0]
|
||||
# A patch for the case when q_output appears before k_input in the graph
|
||||
# Move q_output down right before its first user so that graph remains in
|
||||
# topological order after inserting the apply rope custom op
|
||||
q_match["add_node"] = _move_node_before_first_user(q_match["add_node"])
|
||||
|
||||
if inv_freq_node != k_match["inv_freq"]:
|
||||
raise RuntimeError("Mismatch of freqs_cis (inv_freq) between branches.")
|
||||
# Infer unsqueeze_dim from layout
|
||||
unsq_dim = 1
|
||||
fake = q_node.meta.get("val", None)
|
||||
if fake is not None and len(fake.shape) == 4:
|
||||
# if shape[1] symbolic, it's [B, S, N, D] => BSND -> head dim is 2
|
||||
if isinstance(fake.shape[1], torch.SymInt):
|
||||
unsq_dim = 2
|
||||
else:
|
||||
unsq_dim = 1
|
||||
|
||||
# Sanity check that input layout is BSND (no transpose needed).
|
||||
q_fake = q_node.meta.get("val", None)
|
||||
if q_fake is not None and len(q_fake.shape) > 2:
|
||||
if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)):
|
||||
with graph.inserting_after(_get_last_node([q_node, k_node, cos_node, sin_node])):
|
||||
ds_node = graph.call_function(
|
||||
torch.ops.rope.torch_apply_rope_with_qk_interleaving,
|
||||
args=(q_node, k_node, cos_node, sin_node, unsq_dim),
|
||||
)
|
||||
|
||||
with graph.inserting_after(ds_node):
|
||||
q_out = graph.call_function(operator.getitem, args=(ds_node, 0))
|
||||
k_out = graph.call_function(operator.getitem, args=(ds_node, 1))
|
||||
|
||||
q_out.meta["val"] = q_match["add_node"].meta.get("val", None)
|
||||
k_out.meta["val"] = k_match["add_node"].meta.get("val", None)
|
||||
|
||||
q_match["add_node"].replace_all_uses_with(q_out)
|
||||
k_match["add_node"].replace_all_uses_with(k_out)
|
||||
|
||||
|
||||
def _move_node_before_first_user(node: Node) -> Node:
    """
    Remove `node` from the graph and re-insert a clone of it immediately
    before its earliest user. Returns the new node.

    If `node` has no users, or is already right before its first user,
    this is a no-op and returns the original node.
    """
    graph = node.graph
    ordering = list(graph.nodes)

    users = list(node.users)
    if not users:
        return node

    # locate the earliest user in the current ordering
    first_user = min(users, key=lambda u: ordering.index(u))
    if ordering.index(node) == ordering.index(first_user) - 1:
        return node

    with graph.inserting_before(first_user):
        new_node = graph.node_copy(node, lambda n: n)

    node.replace_all_uses_with(new_node)
    graph.erase_node(node)

    return new_node


def _get_last_node(nodes: Sequence[Node]) -> Node:
    """
    Given a list of FX Nodes,
    return the one that appears last in the graph's execution order.
    """
    if not nodes:
        raise ValueError("`nodes` must be a non-empty sequence of FX Node objects")

    graph = nodes[0].graph
    ordering = list(graph.nodes)

    # Sanity check that all nodes are in same graph
    valid = [n for n in nodes if n in ordering]
    if not valid:
        raise ValueError("None of the provided nodes belong to the same graph")

    last = max(valid, key=lambda n: ordering.index(n))
    return last

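# Hedged usage sketch (illustration only, not part of the diff): exercises the two
# ordering helpers above on a toy torch.fx graph. The traced function and the node
# indices are hypothetical and only demonstrate that both helpers follow
# graph.nodes (execution) order.
def _example_node_ordering() -> None:
    import torch
    from torch.fx import symbolic_trace

    def _toy(x: torch.Tensor) -> torch.Tensor:
        a = x + 1
        b = a * 2
        return a + b

    gm = symbolic_trace(_toy)
    nodes = list(gm.graph.nodes)  # [placeholder, add, mul, add_1, output]
    # _get_last_node picks the node that executes last among the given ones.
    assert _get_last_node(nodes[1:3]) is nodes[2]
    # `add` already sits directly before its first user (`mul`), so this is a no-op.
    assert _move_node_before_first_user(nodes[1]) is nodes[1]
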
+def _validate_rope_inputs(q_node: Node, k_node: Node) -> bool:
+    """
+    Validates that:
+    - The last dimension (head_dim) of both q and k is a multiple of 64.
+    - The dtype of q and k is half precision (bfloat16 or float16).
+    - Layout should be [B, S, N, D] (dim 1 should be symbolic).
+    """
+    for name, node in [("q", q_node), ("k", k_node)]:
+        fake_val = node.meta.get("val", None)
+        if fake_val is None:
            ad_logger.warning(
-                f"""Sanity check failed: q_fake should have shape [b, s, n, d],
-                s should be symbolic and n should be int, instead got shape {q_fake.shape}"""
+                f"Meta['val'] for {name} not available; skipping RoPE transformation."
            )
-        else:
-            ad_logger.warning(
-                f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}"
-            )
+            return False

-    # Retrieve or register the lookup table for inv_freq_node -> cos_sin_flash
-    cache = rope_flash_cache
-    if inv_freq_node in cache:
-        cos_sin_flash = cache[inv_freq_node]
-    else:
-        # Compute the fused cosine/sine cache.
-        with graph.inserting_after(inv_freq_node):
-            real_part = graph.call_function(torch.ops.aten.real, args=(inv_freq_node,))
-            imag_part = graph.call_function(torch.ops.aten.imag, args=(inv_freq_node,))
-        with graph.inserting_after(real_part):
-            cos_sin_flash_3d = graph.call_function(
-                torch.ops.aten.cat, args=((real_part, imag_part), -1)
+        # Check dtype
+        if fake_val.dtype not in (torch.float16, torch.bfloat16):
+            ad_logger.warning(
+                f"""{name} tensor is {fake_val.dtype},
+                expected half precision (float16 or bfloat16). Skipping RoPE transformation."""
            )
-        with graph.inserting_after(cos_sin_flash_3d):
-            cos_sin_flash = graph.call_function(operator.getitem, args=(cos_sin_flash_3d, 0))
-        with graph.inserting_after(cos_sin_flash):
-            cos_sin_flash = graph.call_function(
-                torch.ops.aten.to.dtype, args=(cos_sin_flash, torch.float32)
+            return False

+        # Check head_dim
+        if len(fake_val.shape) < 1:
+            ad_logger.warning(f"{name} tensor has invalid shape {fake_val.shape}.")
+            return False
+        head_dim = fake_val.shape[-1]
+        if isinstance(head_dim, int) and head_dim % 64 != 0:
+            ad_logger.warning(
+                f"{name} head_dim = {head_dim} is not a multiple of 64. Skipping RoPE transformation."
            )
-        cache[inv_freq_node] = cos_sin_flash
+            return False

-    with graph.inserting_before(q_match["out"]):
-        position_ids = _get_position_ids(
-            graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=rope_position_ids_cache
-        )
-        flash_node = graph.call_function(
-            torch.ops.rope.flashinfer,
-            args=(q_node, k_node, position_ids, cos_sin_flash, False),
-        )
+        # Check shape
+        if not isinstance(fake_val.shape[1], torch.SymInt):
+            ad_logger.warning(
+                f"{name} has shape {fake_val.shape} that is not supported. Only support [B, S, N, D] layout.\
                Skipping RoPE transformation."
            )
+            return False

-    with graph.inserting_after(flash_node):
-        raw_q = graph.call_function(operator.getitem, args=(flash_node, 0))
-        raw_k = graph.call_function(operator.getitem, args=(flash_node, 1))

-    raw_q.meta["val"] = q_match["out"].meta.get("val", None)
-    raw_k.meta["val"] = k_match["out"].meta.get("val", None)

-    q_match["out"].replace_all_uses_with(raw_q)
-    k_match["out"].replace_all_uses_with(raw_k)
+    return True

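# Hedged illustration (not part of the diff): the constraints that
# _validate_rope_inputs enforces on the q/k fake tensors, restated on a plain
# tensor. The symbolic-sequence-dim (layout) check cannot be reproduced on an
# eager tensor, so only the dtype and head_dim rules are shown; the helper name
# is hypothetical.
def _example_rope_input_ok(t: "torch.Tensor") -> bool:
    import torch

    half_precision = t.dtype in (torch.float16, torch.bfloat16)
    head_dim_ok = t.ndim >= 1 and t.shape[-1] % 64 == 0
    return half_precision and head_dim_ok
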
def _get_position_ids(
@@ -491,22 +893,17 @@ def _get_position_ids(

    sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, batch_dim))
    sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, seq_dim))
+    bs_seq = graph.call_function(operator.mul, args=(sym_batch, sym_seq))

    # Retrieve device information, ensuring it is a torch.device.
    device = q_node.meta.get("device", "cpu")
    if isinstance(device, str):
        device = torch.device(device)

-    # Build positions: arange(sym_seq) -> view -> expand -> flatten.
-    positions_node = graph.call_function(
+    position_ids = graph.call_function(
        torch.ops.aten.arange,
-        args=(sym_seq,),
+        args=(bs_seq,),
        kwargs={"dtype": torch.float32, "device": device, "pin_memory": False},
    )
-    positions_node = graph.call_function(torch.ops.aten.view, args=(positions_node, (1, -1)))
-    positions_node = graph.call_function(
-        torch.ops.aten.expand, args=(positions_node, (sym_batch, -1))
-    )
-    position_ids = graph.call_function(torch.ops.aten.flatten, args=(positions_node,))
    rope_position_ids_cache["position_ids"] = position_ids
    return position_ids

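# Hedged illustration (not part of the diff): eager-mode equivalent of the
# flattened position ids built above, together with the per-batch expansion of
# the fused cos/sin cache used when lowering to the flashinfer op (mirrors the
# updated unit tests further below). The helper name is hypothetical.
def _example_flattened_positions(cos_sin_cache: "torch.Tensor", batch: int):
    import torch

    seq, head_dim = cos_sin_cache.shape
    # one entry per flattened (batch, position) pair
    positions = torch.arange(batch * seq)
    cache_expanded = (
        cos_sin_cache.unsqueeze(0).expand(batch, -1, -1).contiguous().view(batch * seq, head_dim)
    )
    return positions, cache_expanded
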
@@ -25,12 +25,14 @@ from .library import (
    insert_cached_attention,
    match_attention_layout,
    match_causal_attn_mask,
+    match_complex_rope,
    match_eager_attention,
+    match_explicit_rope,
    match_grouped_attention,
    match_moe_pattern,
    match_repeat_kv,
-    match_rope_v1,
-    match_rope_v2,
+    match_rope_layout,
+    optimize_rope,
    quantize,
    resize_kv_cache,
)
@@ -122,10 +124,10 @@ class InferenceOptimizer:
        egm = match_attention_layout(egm, self.attention_op)

        # Match rope
-        # TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
-        # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
-        egm = match_rope_v1(egm)
-        egm = match_rope_v2(egm)
+        egm = match_explicit_rope(egm)
+        egm = match_complex_rope(egm)
+        # Match RoPE layout expected by our backend
+        egm = match_rope_layout(egm, self.attention_op.get_attention_layout())

        ############################################################################################
        # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
@@ -134,6 +136,10 @@ class InferenceOptimizer:
        # eliminate redundant transpose operations
        egm = eliminate_redundant_transposes(egm)

+        # TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
+        # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
+        egm = optimize_rope(egm)

        # run TP sharding across ranks
        egm = column_row_shard(egm, local_rank, world_size)

@@ -1,5 +1,6 @@
"""Common utils for torch fx graph transformation."""

import operator
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional, Tuple, Union

@@ -320,3 +321,22 @@ def bfs(
            visited.add(next_node)
            queue.append(next_node)
    raise RuntimeError(f"Could not find node with target condition {target}.")


def extract_output_tuple(node: Node, count: int = 2):
    """
    Extract up to `count` outputs from a tuple-producing node.
    Returns a list of length `count`, with None if an output isn't found.
    """
    results = []
    for idx in range(count):
        user_node = next(
            (
                u
                for u in node.users
                if u.op == "call_function" and u.target == operator.getitem and u.args[1] == idx
            ),
            None,
        )
        results.append(user_node)
    return results

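# Hedged usage sketch (illustration only, not part of the diff): shows how
# extract_output_tuple recovers the getitem users of a tuple-producing FX node.
# The traced toy function is hypothetical.
def _example_extract_output_tuple() -> None:
    import torch
    from torch.fx import symbolic_trace

    def _toy(x: torch.Tensor) -> torch.Tensor:
        parts = torch.chunk(x, 2, dim=-1)
        return parts[0] + parts[1]

    gm = symbolic_trace(_toy)
    chunk_node = next(
        n for n in gm.graph.nodes if n.op == "call_function" and "chunk" in str(n.target)
    )
    first, second = extract_output_tuple(chunk_node, 2)
    assert first is not None and second is not None
    assert first.args[1] == 0 and second.args[1] == 1
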
@@ -51,7 +51,7 @@ def run_test(
    num_params_gm = count_parameters(gm)

    assert num_params_model == num_params_gm
-    assert all_close(y_model, y_gm, atol=atol, rtol=rtol)
+    torch.testing.assert_close(y_model, y_gm, atol=atol, rtol=rtol)

    # graph transformation + check
    gm_transformed = transform(gm, *args)
@@ -71,9 +71,7 @@ def run_test(

    if strict_loading:
        # check if output equals without loading state dict
-        assert all_close(y_model, y_transformed, atol=atol, rtol=rtol), (
-            f"{y_model=}, {y_transformed=}"
-        )
+        torch.testing.assert_close(y_model, y_transformed, atol=atol, rtol=rtol)

    if test_load_hook:
        # check if loading hook works from original state dict
@@ -83,9 +81,7 @@ def run_test(

    gm_transformed.load_state_dict(model.state_dict(), strict=True if strict_loading else False)
    y_loaded_from_original = gm_transformed(x)
-    assert all_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol), (
-        f"{y_model=}, {y_loaded_from_original=}"
-    )
+    torch.testing.assert_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol)

    # check if loading hook works from state_dict of a transformed model
    state_dict_sharded = copy.deepcopy(gm_transformed.state_dict())
@@ -95,9 +91,7 @@ def run_test(

    gm_transformed.load_state_dict(state_dict_sharded, strict=True if strict_loading else False)
    y_loaded_from_transformed = gm_transformed(x)
-    assert all_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol), (
-        f"{y_model=}, {y_loaded_from_transformed=}"
-    )
+    torch.testing.assert_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol)

    # check if we can still export the model as expected
    torch_export(gm_transformed, args=(x,))

@@ -222,3 +222,55 @@ def _hf_model_dir_or_hub_id(
        return hf_model_dir
    else:
        return hf_hub_id


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb_explicit(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def apply_rotary_pos_emb_complex(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,  # Expected shape: (B, seq, head_dim//2) and complex dtype.
    unsqueeze_dim: int = 2,
):
    # Reshape the inputs to pair the last dimension.
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Multiply with frequencies. Note that freqs_cis is expected to broadcast with an extra head dim.
    freqs = freqs_cis.unsqueeze(unsqueeze_dim)
    xq_out = torch.view_as_real(xq_complex * freqs).flatten(3)
    xk_out = torch.view_as_real(xk_complex * freqs).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L339
def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """
    Apply rotary positional embeddings by interleaving Q/K,
    indexing cos/sin tables with position_ids, and returning rotated q, k.
    cos: [seq_len, head_dim]
    sin: [seq_len, head_dim]
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

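# Hedged sanity-check sketch (illustration only, not part of the diff): for
# head_dim == 2 the explicit RoPE helper above reduces to a plain 2-D rotation,
# so the unit vector (1, 0) should rotate to (cos t, sin t). Shapes follow the
# [B, N, S, D] convention used with unsqueeze_dim=1; the function name is
# hypothetical.
def _example_rope_rotation_check() -> None:
    import math

    import torch

    t = 0.3
    q = torch.tensor([1.0, 0.0]).view(1, 1, 1, 2)  # [B=1, N=1, S=1, D=2]
    k = q.clone()
    cos = torch.tensor([[math.cos(t), math.cos(t)]])  # [S=1, D=2], duplicated halves
    sin = torch.tensor([[math.sin(t), math.sin(t)]])
    q_rot, _ = apply_rotary_pos_emb_explicit(q, k, cos, sin, unsqueeze_dim=1)
    expected = torch.tensor([math.cos(t), math.sin(t)]).view(1, 1, 1, 2)
    torch.testing.assert_close(q_rot, expected)
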
@ -1,145 +0,0 @@
|
||||
from typing import Tuple
|
||||
|
||||
import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy # noqa: F401
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
|
||||
def apply_rotary_pos_emb(
|
||||
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
|
||||
):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
@pytest.mark.parametrize("head_dim", [64, 256]) # head_dim must be a multiple of 64
|
||||
@pytest.mark.parametrize(
|
||||
"dtype,atol,rtol",
|
||||
[
|
||||
(torch.bfloat16, 1e-4, 1e-4),
|
||||
(torch.float16, 5e-4, 5e-4),
|
||||
],
|
||||
ids=["bfloat16", "float16"], # q/k must be in half precision
|
||||
)
|
||||
def test_flashinfer_and_custom_rope_ops(dtype, atol, rtol, head_dim):
|
||||
device = "cuda"
|
||||
batch = 2
|
||||
seq_len = 4
|
||||
n_head = 3
|
||||
|
||||
# Prepare rotary embedding values.
|
||||
inv_freq = 1.0 / (
|
||||
10000
|
||||
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
|
||||
)
|
||||
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
|
||||
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # [seq_len, head_dim//2]
|
||||
cos_vals = torch.cos(angles) # [seq_len, head_dim//2]
|
||||
sin_vals = torch.sin(angles) # [seq_len, head_dim//2]
|
||||
|
||||
# For direct FlashInfer call: non-interleaved cache [seq_len, head_dim] (concatenated).
|
||||
cos_sin_cache = torch.cat([cos_vals, sin_vals], dim=1)
|
||||
# For HF and the custom op: duplicated layout [seq_len, head_dim].
|
||||
cos_new = torch.cat([cos_vals, cos_vals], dim=-1)
|
||||
sin_new = torch.cat([sin_vals, sin_vals], dim=-1)
|
||||
|
||||
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Direct FlashInfer kernel call.
|
||||
query_flat = query.view(batch * seq_len, n_head * head_dim)
|
||||
key_flat = key.view(batch * seq_len, n_head * head_dim)
|
||||
positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)])
|
||||
q_flash, k_flash = flashinfer.rope.apply_rope_with_cos_sin_cache(
|
||||
positions, query_flat, key_flat, head_dim, cos_sin_cache, is_neox=True
|
||||
)
|
||||
q_flash = q_flash.view(batch, seq_len, n_head, head_dim)
|
||||
k_flash = k_flash.view(batch, seq_len, n_head, head_dim)
|
||||
|
||||
# HF implementation using apply_rotary_pos_emb.
|
||||
# HF expects [batch, n_head, seq_len, head_dim] for unsqueeze_dim=1
|
||||
q_for_hf = query.transpose(1, 2).clone()
|
||||
k_for_hf = key.transpose(1, 2).clone()
|
||||
cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
|
||||
sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
|
||||
q_hf, k_hf = apply_rotary_pos_emb(q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1)
|
||||
|
||||
# Convert outputs to [batch, seq_len, n_head, head_dim]
|
||||
q_hf = q_hf.transpose(1, 2).to(dtype)
|
||||
k_hf = k_hf.transpose(1, 2).to(dtype)
|
||||
|
||||
# Custom op call
|
||||
custom_q, custom_k = torch.ops.rope.flashinfer(query, key, positions, cos_sin_cache, True)
|
||||
|
||||
torch.testing.assert_close(q_hf, q_flash, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(k_hf, k_flash, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(q_hf, custom_q, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(k_hf, custom_k, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
# Version 2: complex multiplication approach
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("head_dim", [64, 256]) # Must be a multiple of 64
|
||||
@pytest.mark.parametrize(
|
||||
"dtype,atol,rtol",
|
||||
[
|
||||
(torch.bfloat16, 1e-5, 1e-5),
|
||||
(torch.float16, 5e-4, 5e-4),
|
||||
],
|
||||
ids=["bfloat16", "float16"], # q/k must be in half precision
|
||||
)
|
||||
def test_flashinfer_complex_rotary(dtype, atol, rtol, head_dim):
|
||||
device = "cuda"
|
||||
batch = 2
|
||||
seq_len = 4
|
||||
n_head = 3
|
||||
|
||||
inv_freq = 1.0 / (
|
||||
10000
|
||||
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
|
||||
)
|
||||
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
|
||||
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # shape: (seq_len, head_dim//2)
|
||||
freqs_cis = torch.polar(torch.ones((seq_len, head_dim // 2), device=device), angles)
|
||||
freqs_cis = freqs_cis.unsqueeze(0).expand(batch, -1, -1) # shape: (B, seq, head_dim//2)
|
||||
|
||||
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
|
||||
out_q_v2, out_k_v2 = apply_rotary_emb(query, key, freqs_cis)
|
||||
|
||||
cos_from_freqs = torch.real(freqs_cis) # (B, seq, head_dim//2)
|
||||
sin_from_freqs = torch.imag(freqs_cis) # (B, seq, head_dim//2)
|
||||
cos_sin_cache = torch.cat([cos_from_freqs, sin_from_freqs], dim=-1)[0] # (seq, head_dim))
|
||||
|
||||
# q/k of llama4 rope is interleaved
|
||||
positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)])
|
||||
custom_q, custom_k = torch.ops.rope.flashinfer(query, key, positions, cos_sin_cache, False)
|
||||
|
||||
torch.testing.assert_close(out_q_v2, custom_q, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(out_k_v2, custom_k, rtol=rtol, atol=atol)
|
||||
@ -0,0 +1,294 @@
|
||||
import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
from _model_test_utils import (
|
||||
apply_rotary_pos_emb_complex,
|
||||
apply_rotary_pos_emb_ds,
|
||||
apply_rotary_pos_emb_explicit,
|
||||
)
|
||||
|
||||
import tensorrt_llm._torch.auto_deploy # noqa: F401
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("head_dim", [64, 256]) # head_dim must be a multiple of 64
|
||||
@pytest.mark.parametrize(
|
||||
"dtype,atol,rtol",
|
||||
[
|
||||
(torch.bfloat16, 1e-4, 1e-4),
|
||||
(torch.float16, 5e-4, 5e-4),
|
||||
],
|
||||
ids=["bfloat16", "float16"], # q/k must be in half precision
|
||||
)
|
||||
def test_flashinfer_custom_op_and_hf_impl(dtype, atol, rtol, head_dim):
|
||||
"""
|
||||
Verify FlashInfer's Neox RoPE kernel against HF's apply_rotary_pos_emb:
|
||||
- Q/K: [B, S, N, D] non-interleaved half-precision.
|
||||
- cos_sin_cache: [S, D] = [cos||sin] concatenated.
|
||||
- HF path: Q/K → [B, N, S, D], cos_new/sin_new: [S, D] duplicated, then broadcast to [B, S, D].
|
||||
"""
|
||||
device = "cuda"
|
||||
batch = 2
|
||||
seq_len = 4
|
||||
n_head = 3
|
||||
|
||||
# Prepare rotary embedding values.
|
||||
inv_freq = 1.0 / (
|
||||
10000
|
||||
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
|
||||
)
|
||||
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
|
||||
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # [seq_len, head_dim//2]
|
||||
cos_vals = torch.cos(angles) # [seq_len, head_dim//2]
|
||||
sin_vals = torch.sin(angles) # [seq_len, head_dim//2]
|
||||
|
||||
# For direct FlashInfer call: non-interleaved cache [seq_len, head_dim] (concatenated).
|
||||
cos_sin_cache = torch.cat([cos_vals, sin_vals], dim=1)
|
||||
cos_sin_cache_expand = (
|
||||
cos_sin_cache.unsqueeze(0).expand(batch, -1, -1).contiguous().view(batch * seq_len, -1)
|
||||
) # [batch * seq_len, head_dim]
|
||||
# For HF and the custom op: duplicated layout [seq_len, head_dim].
|
||||
cos_new = torch.cat([cos_vals, cos_vals], dim=-1)
|
||||
sin_new = torch.cat([sin_vals, sin_vals], dim=-1)
|
||||
|
||||
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Direct FlashInfer kernel call.
|
||||
query_flat = query.view(batch * seq_len, n_head * head_dim)
|
||||
key_flat = key.view(batch * seq_len, n_head * head_dim)
|
||||
positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)])
|
||||
q_flash, k_flash = flashinfer.rope.apply_rope_with_cos_sin_cache(
|
||||
positions, query_flat, key_flat, head_dim, cos_sin_cache, is_neox=True
|
||||
)
|
||||
q_flash = q_flash.view(batch, seq_len, n_head, head_dim)
|
||||
k_flash = k_flash.view(batch, seq_len, n_head, head_dim)
|
||||
|
||||
# HF implementation using apply_rotary_pos_emb.
|
||||
# HF expects [batch, n_head, seq_len, head_dim] for unsqueeze_dim=1
|
||||
q_for_hf = query.transpose(1, 2).clone()
|
||||
k_for_hf = key.transpose(1, 2).clone()
|
||||
cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
|
||||
sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
|
||||
q_hf, k_hf = apply_rotary_pos_emb_explicit(
|
||||
q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
# Convert outputs to [batch, seq_len, n_head, head_dim]
|
||||
q_hf = q_hf.transpose(1, 2).to(dtype)
|
||||
k_hf = k_hf.transpose(1, 2).to(dtype)
|
||||
|
||||
# Custom op call
|
||||
positions_flat = torch.arange(batch * seq_len, device=device)
|
||||
custom_q, custom_k = torch.ops.rope.flashinfer(
|
||||
query, key, positions_flat, cos_sin_cache_expand, True
|
||||
)
|
||||
|
||||
torch.testing.assert_close(q_hf, q_flash, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(k_hf, k_flash, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(q_hf, custom_q, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(k_hf, custom_k, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("head_dim", [64, 256]) # Must be a multiple of 64
|
||||
@pytest.mark.parametrize(
|
||||
"dtype,atol,rtol",
|
||||
[
|
||||
(torch.bfloat16, 1e-5, 1e-5),
|
||||
(torch.float16, 5e-4, 5e-4),
|
||||
],
|
||||
ids=["bfloat16", "float16"], # q/k must be in half precision
|
||||
)
|
||||
def test_flashinfer_custom_op_and_complex_impl(dtype, atol, rtol, head_dim):
|
||||
"""
|
||||
Check FlashInfer's RoPE matches the complex-multiplication approach:
|
||||
- Q/K: [B, S, N, D] non-interleaved half-precision.
|
||||
- freqs_cis: [B, S, D/2] complex polar values.
|
||||
- flashinfer uses cos_sin_cache: [S, D] interleaved from real/imag of freqs_cis.
|
||||
"""
|
||||
device = "cuda"
|
||||
batch = 2
|
||||
seq_len = 4
|
||||
n_head = 3
|
||||
|
||||
inv_freq = 1.0 / (
|
||||
10000
|
||||
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
|
||||
)
|
||||
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
|
||||
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # shape: (seq_len, head_dim//2)
|
||||
freqs_cis = torch.polar(torch.ones((seq_len, head_dim // 2), device=device), angles)
|
||||
freqs_cis = freqs_cis.unsqueeze(0).expand(batch, -1, -1) # shape: (B, seq, head_dim//2)
|
||||
|
||||
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
|
||||
out_q_v2, out_k_v2 = apply_rotary_pos_emb_complex(query, key, freqs_cis)
|
||||
|
||||
cos_from_freqs = torch.real(freqs_cis) # (B, seq, head_dim//2)
|
||||
sin_from_freqs = torch.imag(freqs_cis) # (B, seq, head_dim//2)
|
||||
cos_sin_cache = torch.cat([cos_from_freqs, sin_from_freqs], dim=-1)[0] # (seq, head_dim))
|
||||
cos_sin_cache_expand = (
|
||||
cos_sin_cache.unsqueeze(0).expand(batch, -1, -1).contiguous().view(batch * seq_len, -1)
|
||||
) # [batch * seq_len, head_dim]
|
||||
|
||||
# q/k of llama4 rope is interleaved
|
||||
positions_flat = torch.arange(batch * seq_len, device=device)
|
||||
custom_q, custom_k = torch.ops.rope.flashinfer(
|
||||
query, key, positions_flat, cos_sin_cache_expand, False
|
||||
)
|
||||
|
||||
torch.testing.assert_close(out_q_v2, custom_q, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(out_k_v2, custom_k, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
# Copy of TritonWithFlattenedInputs._precompute_freqs_cis
|
||||
def precompute_freqs_cis_interleaved(
|
||||
seq_len: int, head_dim: int, dtype: torch.dtype, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Precompute interleaved cosine and sine frequency cache for rotary position embeddings (RoPE).
|
||||
|
||||
Returns a tensor of shape [seq_len, head_dim//2, 2], where the last dimension
|
||||
alternates [cos, sin] values for each rotary frequency.
|
||||
cache[s, i, 0] == cos(position=s · inv_freq[i])
|
||||
cache[s, i, 1] == sin(position=s · inv_freq[i]).
|
||||
"""
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
|
||||
t = torch.arange(seq_len, device=device)
|
||||
angles = t.unsqueeze(1) * inv_freq.unsqueeze(0)
|
||||
freqs_cis = torch.polar(torch.ones_like(angles), angles)
|
||||
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
|
||||
return cache.to(dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("layout", ["bsnd", "bnsd"])
|
||||
@pytest.mark.parametrize("head_dim", [64, 256])
|
||||
@pytest.mark.parametrize(
|
||||
"dtype, atol, rtol",
|
||||
[
|
||||
(torch.bfloat16, 1e-4, 1e-4),
|
||||
(torch.float16, 5e-4, 5e-4),
|
||||
],
|
||||
ids=["bfloat16", "float16"],
|
||||
)
|
||||
def test_triton_custom_op_and_hf_impl(layout, head_dim, dtype, atol, rtol):
|
||||
"""
|
||||
Validate custom Triton apply_rope_with_input_pos against HF's apply_rotary_pos_emb:
|
||||
- Q/K: layout 'bsnd'→[B,S,N,D] or 'bnsd'→[B,N,S,D], non-interleaved half-precision.
|
||||
- cosin_cache: [S, D/2, 2] interleaved [cos,sin].
|
||||
- HF path: cos_full/sin_full: [S, D] then expanded to [B, S, D].
|
||||
"""
|
||||
device = "cuda"
|
||||
batch, seq_len, n_head = 2, 4, 3
|
||||
|
||||
# build cache and per-batch zero positions
|
||||
cosin_cache = precompute_freqs_cis_interleaved(seq_len, head_dim, dtype, device) # [S, D/2, 2]
|
||||
positions = torch.zeros(batch, dtype=torch.int32, device=device)
|
||||
|
||||
if layout == "bsnd":
|
||||
q = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
k = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
unsq = 2
|
||||
else: # "bnsd"
|
||||
q = torch.randn(batch, n_head, seq_len, head_dim, dtype=dtype, device=device)
|
||||
k = torch.randn(batch, n_head, seq_len, head_dim, dtype=dtype, device=device)
|
||||
unsq = 1
|
||||
|
||||
# build HF float32 cos/sin full tensors
|
||||
cos_f32 = cosin_cache[..., 0].to(torch.float32) # [S, H/2]
|
||||
sin_f32 = cosin_cache[..., 1].to(torch.float32) # [S, H/2]
|
||||
cos_full = torch.cat([cos_f32, cos_f32], dim=1) # [S, H]
|
||||
sin_full = torch.cat([sin_f32, sin_f32], dim=1) # [S, H]
|
||||
cos_exp = cos_full.unsqueeze(0).expand(batch, -1, -1) # [B, S, H]
|
||||
sin_exp = sin_full.unsqueeze(0).expand(batch, -1, -1) # [B, S, H]
|
||||
|
||||
# HF reference in float32, then cast back
|
||||
q_f32, k_f32 = apply_rotary_pos_emb_explicit(
|
||||
q.to(torch.float32), k.to(torch.float32), cos_exp, sin_exp, unsqueeze_dim=unsq
|
||||
)
|
||||
q_hf = q_f32.to(dtype)
|
||||
k_hf = k_f32.to(dtype)
|
||||
|
||||
q_out = torch.ops.rope.apply_rope_with_input_pos(q, cosin_cache, positions, layout)
|
||||
k_out = torch.ops.rope.apply_rope_with_input_pos(k, cosin_cache, positions, layout)
|
||||
|
||||
torch.testing.assert_close(q_hf, q_out, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(k_hf, k_out, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def inverse_interleave_permute_for_rotary(x: torch.Tensor) -> torch.Tensor:
|
||||
b, h, s, d = x.shape
|
||||
x = x.view(b, h, s, 2, d // 2)
|
||||
x = x.transpose(4, 3)
|
||||
return x.reshape(b, h, s, d)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("head_dim", [64, 256])
|
||||
@pytest.mark.parametrize(
|
||||
"dtype, atol, rtol",
|
||||
[
|
||||
(torch.bfloat16, 1e-4, 1e-4),
|
||||
(torch.float16, 5e-4, 5e-4),
|
||||
],
|
||||
ids=["bfloat16", "float16"],
|
||||
)
|
||||
def test_ds_impl_and_hf_impl(dtype, head_dim, atol, rtol):
|
||||
"""
|
||||
Ensure Deepseek's interleaved-Q/K RoPE matches HF apply_rotary_pos_emb:
|
||||
- DS Q/K: [B, N, S, D] channel-interleaved in last dim.
|
||||
- cos_new/sin_new: [S, D] duplicated real values.
|
||||
- HF path: Q/K → [B,N,S,D], cos_expand/sin_expand: [B,S,D], unsqueezed at dim=1.
|
||||
"""
|
||||
device = "cuda"
|
||||
batch = 2
|
||||
seq_len = 4
|
||||
n_head = 3
|
||||
|
||||
position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch, seq_len)
|
||||
inv_freq = 1.0 / (
|
||||
10000
|
||||
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
|
||||
)
|
||||
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
|
||||
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # [seq_len, head_dim//2]
|
||||
cos_vals = torch.cos(angles) # [seq_len, head_dim//2]
|
||||
sin_vals = torch.sin(angles) # [seq_len, head_dim//2]
|
||||
# duplicate to shape [seq_len, head_dim]
|
||||
cos_new = torch.cat([cos_vals, cos_vals], dim=-1)
|
||||
sin_new = torch.cat([sin_vals, sin_vals], dim=-1)
|
||||
|
||||
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
|
||||
|
||||
# HF torch expects inputs of shape [B, N, S, D]
|
||||
q_for_hf = query.transpose(1, 2).clone()
|
||||
k_for_hf = key.transpose(1, 2).clone()
|
||||
cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
|
||||
sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
|
||||
q_rotated_hf, k_rotated_hf = apply_rotary_pos_emb_explicit(
|
||||
q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1
|
||||
)
|
||||
q_rotated_hf = q_rotated_hf.transpose(1, 2).to(torch.float32)
|
||||
k_rotated_hf = k_rotated_hf.transpose(1, 2).to(torch.float32)
|
||||
|
||||
q_for_hf2 = inverse_interleave_permute_for_rotary(q_for_hf.clone())
|
||||
k_for_hf2 = inverse_interleave_permute_for_rotary(k_for_hf.clone())
|
||||
|
||||
# adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L134
|
||||
t = torch.arange(seq_len, device=device, dtype=torch.float32)
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos_ds = emb.cos()
|
||||
sin_ds = emb.sin()
|
||||
|
||||
torch.testing.assert_close(cos_ds, cos_new, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(sin_ds, sin_new, rtol=rtol, atol=atol)
|
||||
|
||||
q_rotated_hf2, k_rotated_hf2 = apply_rotary_pos_emb_ds(
|
||||
q_for_hf2, k_for_hf2, cos_new, sin_new, position_ids, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
torch.testing.assert_close(q_rotated_hf2.transpose(1, 2), q_rotated_hf, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(k_rotated_hf2.transpose(1, 2), k_rotated_hf, rtol=rtol, atol=atol)
|
||||
@ -1,211 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
from _graph_test_helpers import run_test
|
||||
from torch.export import Dim
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library.rope import (
|
||||
match_rope_v1,
|
||||
match_rope_v2,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
|
||||
):
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def _precompute_freqs_cis(seq_len: int, head_dim: int, rope_theta: float):
|
||||
dtype = torch.float32
|
||||
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
|
||||
positions = torch.arange(seq_len, dtype=torch.float32)
|
||||
freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos().to(dtype)
|
||||
sin = emb.sin().to(dtype)
|
||||
return cos, sin
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2) and complex dtype.
|
||||
):
|
||||
# Reshape the inputs to pair the last dimension.
|
||||
xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
|
||||
# Multiply with frequencies. Note that freqs_cis is expected to broadcast with an extra head dim.
|
||||
xq_out = torch.view_as_real(xq_complex * freqs_cis[:, :, None, :]).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_complex * freqs_cis[:, :, None, :]).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
def _precompute_freqs_cis_v2(seq_len: int, head_dim: int, rope_theta: float):
|
||||
"""
|
||||
Compute the frequency tensor for the complex multiplication RoPE variant.
|
||||
Returns a complex tensor of shape (seq_len, head_dim//2).
|
||||
"""
|
||||
inv_freq = 1.0 / (
|
||||
rope_theta ** (torch.arange(0, head_dim // 2, dtype=torch.float32) / (head_dim // 2))
|
||||
)
|
||||
positions = torch.arange(seq_len, dtype=torch.float32)
|
||||
angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # (seq_len, head_dim//2)
|
||||
# Create a complex tensor from magnitude=1 and the computed angles.
|
||||
freqs_cis = torch.polar(torch.ones_like(angles), angles)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class RotaryModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
max_seq_len: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
layout: str = "BNSD",
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.layout = layout
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim)
|
||||
self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
q = self.linear_q(x)
|
||||
k = self.linear_k(x)
|
||||
|
||||
batch, seq, _ = q.shape
|
||||
q = q.view(batch, seq, self.num_heads, self.head_dim)
|
||||
k = k.view(batch, seq, self.num_kv_heads, self.head_dim)
|
||||
|
||||
if self.layout == "BNSD":
|
||||
q = q.permute(0, 2, 1, 3).contiguous() # [B, N, S, D]
|
||||
k = k.permute(0, 2, 1, 3).contiguous()
|
||||
unsqueeze_dim = 1
|
||||
else: # BSND
|
||||
unsqueeze_dim = 2
|
||||
|
||||
cos, sin = _precompute_freqs_cis(seq, self.head_dim, rope_theta=10000)
|
||||
cos = cos.to(q.device).unsqueeze(0).expand(batch, -1, -1)
|
||||
sin = sin.to(q.device).unsqueeze(0).expand(batch, -1, -1)
|
||||
|
||||
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim)
|
||||
if self.layout == "BNSD":
|
||||
# [B, N, S, D] -> [B, S, N*D]
|
||||
q_embed = q_embed.permute(0, 2, 1, 3).reshape(batch, seq, -1)
|
||||
k_embed = k_embed.permute(0, 2, 1, 3).reshape(batch, seq, -1)
|
||||
else: # BSND
|
||||
q_embed = q_embed.reshape(batch, seq, -1)
|
||||
k_embed = k_embed.reshape(batch, seq, -1)
|
||||
|
||||
output = torch.cat([q_embed, k_embed], dim=-1)
|
||||
return output.to(torch.float16)
|
||||
|
||||
def get_dynamic_shapes(self):
|
||||
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}
|
||||
|
||||
|
||||
class RotaryModelV2(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, max_seq_len: int, num_heads: int, num_kv_heads: int):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
|
||||
self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim)
|
||||
self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch, seq, _ = x.shape
|
||||
|
||||
q = self.linear_q(x).view(batch, seq, self.num_heads, self.head_dim)
|
||||
k = self.linear_k(x).view(batch, seq, self.num_kv_heads, self.head_dim)
|
||||
|
||||
freqs_cis = _precompute_freqs_cis_v2(seq, self.head_dim, rope_theta=10000)
|
||||
freqs_cis = freqs_cis.to(x.device).unsqueeze(0).expand(batch, -1, -1)
|
||||
|
||||
q_embed, k_embed = apply_rotary_emb(q, k, freqs_cis)
|
||||
|
||||
q_embed = q_embed.reshape(batch, seq, -1)
|
||||
k_embed = k_embed.reshape(batch, seq, -1)
|
||||
|
||||
output = torch.cat([q_embed, k_embed], dim=-1)
|
||||
return output.to(torch.float16)
|
||||
|
||||
def get_dynamic_shapes(self):
|
||||
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("layout", ["BNSD", "BSND"])
|
||||
@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4)])
|
||||
@torch.inference_mode()
|
||||
def test_match_rope(layout, num_heads, num_kv_heads):
|
||||
batch_size, seq_len = 8, 16
|
||||
hidden_size = 512
|
||||
max_position_embeddings = seq_len
|
||||
|
||||
model = RotaryModel(
|
||||
hidden_size, max_position_embeddings, num_heads, num_kv_heads, layout=layout
|
||||
).to("cuda", dtype=torch.float16)
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
|
||||
dynamic_shapes = model.get_dynamic_shapes()
|
||||
|
||||
_ = run_test(
|
||||
model,
|
||||
x,
|
||||
match_rope_v1,
|
||||
lambda gm: any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes),
|
||||
lambda num_p_og: num_p_og,
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
test_load_hook=True,
|
||||
strict_loading=True,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4)])
|
||||
@torch.inference_mode()
|
||||
def test_match_rope_v2(num_heads, num_kv_heads):
|
||||
batch_size, seq_len = 8, 16
|
||||
hidden_size = 512
|
||||
max_position_embeddings = seq_len
|
||||
|
||||
model = RotaryModelV2(hidden_size, max_position_embeddings, num_heads, num_kv_heads).to(
|
||||
"cuda", dtype=torch.float16
|
||||
)
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
|
||||
dynamic_shapes = model.get_dynamic_shapes()
|
||||
|
||||
_ = run_test(
|
||||
model,
|
||||
x,
|
||||
match_rope_v2,
|
||||
lambda gm: any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes),
|
||||
lambda num_p_og: num_p_og,
|
||||
atol=1e-3,
|
||||
rtol=1e-3,
|
||||
test_load_hook=True,
|
||||
strict_loading=True,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
@ -0,0 +1,429 @@
|
||||
import pytest
|
||||
import torch
|
||||
from _graph_test_helpers import run_test
|
||||
from _model_test_utils import (
|
||||
apply_rotary_pos_emb_complex,
|
||||
apply_rotary_pos_emb_ds,
|
||||
apply_rotary_pos_emb_explicit,
|
||||
)
|
||||
from torch.export import Dim
|
||||
|
||||
from tensorrt_llm._torch.auto_deploy.transformations.library.rope import (
|
||||
match_complex_rope,
|
||||
match_explicit_rope,
|
||||
match_rope_layout,
|
||||
optimize_rope,
|
||||
)
|
||||
from tensorrt_llm._torch.auto_deploy.utils.node_utils import extract_output_tuple, is_op
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def _precompute_freqs_cis_explicit(seq_len: int, head_dim: int, rope_theta: float):
|
||||
dtype = torch.float32
|
||||
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
|
||||
positions = torch.arange(seq_len, dtype=torch.float32)
|
||||
freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos().to(dtype)
|
||||
sin = emb.sin().to(dtype)
|
||||
return cos, sin
|
||||
|
||||
|
||||
def _precompute_freqs_cis_complex(seq_len: int, head_dim: int, rope_theta: float):
|
||||
"""
|
||||
Compute the frequency tensor for the complex multiplication RoPE variant.
|
||||
Returns a complex tensor of shape (seq_len, head_dim//2).
|
||||
"""
|
||||
inv_freq = 1.0 / (
|
||||
rope_theta ** (torch.arange(0, head_dim // 2, dtype=torch.float32) / (head_dim // 2))
|
||||
)
|
||||
positions = torch.arange(seq_len, dtype=torch.float32)
|
||||
angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # (seq_len, head_dim//2)
|
||||
# Create a complex tensor from magnitude=1 and the computed angles.
|
||||
freqs_cis = torch.polar(torch.ones_like(angles), angles)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class RoPEModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
max_seq_len: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
variant: str = "explicit", # "explicit" or "complex"
|
||||
mode: str = "match", # "match" or "optimize"
|
||||
layout: str = "BNSD", # "BNSD" or "BSND"
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.variant = variant
|
||||
self.mode = mode
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
self.layout = layout
|
||||
|
||||
self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim)
|
||||
self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, s, _ = x.shape
|
||||
q = self.linear_q(x)
|
||||
k = self.linear_k(x)
|
||||
|
||||
if self.variant == "explicit":
|
||||
# reshape and permute if BNSD layout
|
||||
q = q.view(b, s, self.num_heads, self.head_dim)
|
||||
k = k.view(b, s, self.num_kv_heads, self.head_dim)
|
||||
if self.layout == "BNSD":
|
||||
q = q.permute(0, 2, 1, 3).contiguous()
|
||||
k = k.permute(0, 2, 1, 3).contiguous()
|
||||
unsq_dim = 1
|
||||
else:
|
||||
unsq_dim = 2
|
||||
|
||||
cos, sin = _precompute_freqs_cis_explicit(s, self.head_dim, rope_theta=10000)
|
||||
cos = cos.to(x.device).unsqueeze(0).expand(b, -1, -1)
|
||||
sin = sin.to(x.device).unsqueeze(0).expand(b, -1, -1)
|
||||
|
||||
if self.mode == "match":
|
||||
q_out, k_out = apply_rotary_pos_emb_explicit(q, k, cos, sin, unsq_dim)
|
||||
else: # optimize
|
||||
q_out, k_out = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin(
|
||||
q, k, cos, sin, unsq_dim
|
||||
)
|
||||
|
||||
# revert layout and flatten
|
||||
if self.layout == "BNSD":
|
||||
q_out = q_out.permute(0, 2, 1, 3).reshape(b, s, -1)
|
||||
k_out = k_out.permute(0, 2, 1, 3).reshape(b, s, -1)
|
||||
else:
|
||||
q_out = q_out.reshape(b, s, -1)
|
||||
k_out = k_out.reshape(b, s, -1)
|
||||
|
||||
else: # complex variant
|
||||
q = q.view(b, s, self.num_heads, self.head_dim)
|
||||
k = k.view(b, s, self.num_kv_heads, self.head_dim)
|
||||
if self.layout == "BNSD":
|
||||
q = q.permute(0, 2, 1, 3).contiguous()
|
||||
k = k.permute(0, 2, 1, 3).contiguous()
|
||||
unsq_dim = 1
|
||||
else:
|
||||
unsq_dim = 2
|
||||
|
||||
freqs = _precompute_freqs_cis_complex(s, self.head_dim, rope_theta=10000)
|
||||
freqs = freqs.to(x.device).unsqueeze(0).expand(b, -1, -1)
|
||||
|
||||
if self.mode == "match":
|
||||
q_out, k_out = apply_rotary_pos_emb_complex(q, k, freqs, unsq_dim)
|
||||
else:
|
||||
q_out, k_out = torch.ops.rope.torch_apply_rope_with_complex_freqs(
|
||||
q, k, freqs, unsq_dim
|
||||
)
|
||||
|
||||
# revert layout and flatten
|
||||
if self.layout == "BNSD":
|
||||
q_out = q_out.permute(0, 2, 1, 3).reshape(b, s, -1)
|
||||
k_out = k_out.permute(0, 2, 1, 3).reshape(b, s, -1)
|
||||
else:
|
||||
q_out = q_out.reshape(b, s, -1)
|
||||
k_out = k_out.reshape(b, s, -1)
|
||||
|
||||
out = torch.cat([q_out, k_out], dim=-1)
|
||||
return out.to(torch.float16) if self.mode == "match" else out
|
||||
|
||||
def get_dynamic_shapes(self):
|
||||
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"transformation,variant,layout,batch_size,seq_len,num_heads,num_kv_heads,atol,rtol, target_layout",
|
||||
[
|
||||
("match", "explicit", "BNSD", 8, 16, 8, 8, 1e-2, 1e-2, None),
|
||||
("match", "explicit", "BSND", 8, 16, 8, 4, 1e-2, 1e-2, None),
|
||||
("match", "complex", "BNSD", 8, 16, 8, 8, 1e-3, 1e-3, None),
|
||||
("match", "complex", "BSND", 8, 16, 8, 4, 1e-3, 1e-3, None),
|
||||
("match_layout", "explicit", "BNSD", 4, 12, 8, 8, 1e-3, 1e-3, "BSND"),
|
||||
("match_layout", "explicit", "BNSD", 4, 12, 8, 8, 1e-3, 1e-3, "BNSD"),
|
||||
("match_layout", "complex", "BNSD", 4, 12, 8, 8, 1e-3, 1e-3, "BSND"),
|
||||
("match_layout", "complex", "BSND", 4, 12, 8, 8, 1e-3, 1e-3, "BSND"),
|
||||
pytest.param(
|
||||
"optimize",
|
||||
"explicit",
|
||||
"BNSD",
|
||||
4,
|
||||
12,
|
||||
8,
|
||||
8,
|
||||
1e-3,
|
||||
1e-3,
|
||||
None,
|
||||
marks=pytest.mark.xfail(
|
||||
reason="flashinfer op does not support BNSD layout", strict=True
|
||||
),
|
||||
),
|
||||
("optimize", "explicit", "BSND", 4, 12, 8, 4, 1e-3, 1e-3, None),
|
||||
pytest.param(
|
||||
"optimize",
|
||||
"complex",
|
||||
"BNSD",
|
||||
4,
|
||||
12,
|
||||
8,
|
||||
8,
|
||||
1e-3,
|
||||
1e-3,
|
||||
None,
|
||||
marks=pytest.mark.xfail(
|
||||
reason="flashinfer op does not support BNSD layout", strict=True
|
||||
),
|
||||
),
|
||||
("optimize", "complex", "BSND", 4, 12, 8, 4, 1e-3, 1e-3, None),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_rope_variants(
|
||||
transformation,
|
||||
variant,
|
||||
layout,
|
||||
batch_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
atol,
|
||||
rtol,
|
||||
target_layout,
|
||||
):
|
||||
hidden_size = 512
|
||||
model = RoPEModel(
|
||||
hidden_size,
|
||||
seq_len,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
variant=variant,
|
||||
mode=transformation,
|
||||
layout=layout or "BNSD",
|
||||
).to("cuda", torch.float16)
|
||||
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
|
||||
dyn = model.get_dynamic_shapes()
|
||||
|
||||
if transformation == "match":
|
||||
fn = match_explicit_rope if variant == "explicit" else match_complex_rope
|
||||
check_op = (
|
||||
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin
|
||||
if variant == "explicit"
|
||||
else torch.ops.rope.torch_apply_rope_with_complex_freqs
|
||||
)
|
||||
|
||||
def checker(gm):
|
||||
return any(is_op(n, check_op) for n in gm.graph.nodes)
|
||||
|
||||
elif transformation == "match_layout":
|
||||
fn = match_rope_layout
|
||||
|
||||
def checker(gm):
|
||||
for n in gm.graph.nodes:
|
||||
if is_op(
|
||||
n,
|
||||
{
|
||||
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin,
|
||||
torch.ops.rope.torch_apply_rope_with_complex_freqs,
|
||||
},
|
||||
):
|
||||
q_arg, k_arg, *rest = n.args
|
||||
if not (
|
||||
is_op(q_arg, torch.ops.aten.contiguous)
|
||||
and is_op(k_arg, torch.ops.aten.contiguous)
|
||||
):
|
||||
matched = False
|
||||
break
|
||||
|
||||
old_q, old_k = extract_output_tuple(n, 2)
|
||||
if old_q is None or old_k is None:
|
||||
matched = False
|
||||
break
|
||||
q_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_q.users)
|
||||
k_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_k.users)
|
||||
matched = q_transposed and k_transposed
|
||||
|
||||
return matched if layout != target_layout else not matched
|
||||
|
||||
else:
|
||||
fn = optimize_rope
|
||||
|
||||
def checker(gm):
|
||||
return any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes)
|
||||
|
||||
if target_layout:
|
||||
_ = run_test(
|
||||
model,
|
||||
x,
|
||||
fn,
|
||||
checker,
|
||||
lambda n: n,
|
||||
atol, # atol
|
||||
rtol, # rtol
|
||||
True, # test_load_hook
|
||||
True, # strict_loading
|
||||
dyn, # dynamic_shapes
|
||||
target_layout,
|
||||
)
|
||||
else:
|
||||
_ = run_test(
|
||||
model,
|
||||
x,
|
||||
fn,
|
||||
checker,
|
||||
lambda n: n,
|
||||
atol, # atol
|
||||
rtol, # rtol
|
||||
True, # test_load_hook
|
||||
True, # strict_loading
|
||||
dyn, # dynamic_shapes
|
||||
)
|
||||
|
||||
|
||||
class DSRotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
t = torch.arange(max_position_embeddings, device=device, dtype=inv_freq.dtype)
|
||||
freqs = torch.outer(t, inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# returns [seq_len, head_dim] cos & sin
|
||||
return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype)
|
||||
|
||||
|
||||
class DSModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size, max_seq, n_head, n_kv, layout="BNSD", mode: str = "match"):
|
||||
super().__init__()
|
||||
self.hdim = hidden_size // n_head
|
||||
self.layout = layout
|
||||
self.q_lin = torch.nn.Linear(hidden_size, n_head * self.hdim)
|
||||
self.k_lin = torch.nn.Linear(hidden_size, n_kv * self.hdim)
|
||||
self.rotary = DSRotaryEmbedding(self.hdim, max_seq, base=10000, device="cuda")
|
||||
self.mode = mode # "match" or "optimize"
|
||||
|
||||
def forward(self, x):
|
||||
b, s, _ = x.shape
|
||||
q = self.q_lin(x).view(b, s, -1, self.hdim)
|
||||
k = self.k_lin(x).view(b, s, -1, self.hdim)
|
||||
if self.layout == "BNSD":
|
||||
# to [B, N, S, D]
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
unsq_dim = 1
|
||||
else:
|
||||
unsq_dim = 2
|
||||
cos, sin = self.rotary(x, seq_len=s)
|
||||
# build position_ids [B, S]
|
||||
pos_ids = torch.arange(s, device=x.device).unsqueeze(0).expand(b, s)
|
||||
if self.mode == "match":
|
||||
q_out, k_out = apply_rotary_pos_emb_ds(q, k, cos, sin, pos_ids, unsqueeze_dim=unsq_dim)
|
||||
else:
|
||||
cos = cos[pos_ids]
|
||||
sin = sin[pos_ids]
|
||||
q_out, k_out = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
|
||||
q, k, cos, sin, unsq_dim
|
||||
)
|
||||
if self.layout == "BNSD":
|
||||
# back to [B, S, N*D]
|
||||
q_out = q_out.permute(0, 2, 1, 3).reshape(b, s, -1)
|
||||
k_out = k_out.permute(0, 2, 1, 3).reshape(b, s, -1)
|
||||
else:
|
||||
q_out = q_out.reshape(b, s, -1)
|
||||
k_out = k_out.reshape(b, s, -1)
|
||||
return torch.cat([q_out, k_out], dim=-1)
|
||||
|
||||
def get_dynamic_shapes(self):
|
||||
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"layout,num_heads,num_kv_heads,mode, target_layout",
|
||||
[
|
||||
("BNSD", 8, 8, "match", None),
|
||||
("BSND", 8, 4, "match", None),
|
||||
("BNSD", 8, 8, "match_layout", "BNSD"),
|
||||
("BSND", 8, 4, "match_layout", "BNSD"),
|
||||
("BSND", 8, 4, "match_layout", "BSND"),
|
||||
("BNSD", 8, 4, "match_layout", "BSND"),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target_layout):
|
||||
batch, seq, hid = 4, 12, 512
|
||||
model = DSModel(hid, 16, num_heads, num_kv_heads, layout=layout, mode=mode)
|
||||
model = model.to("cuda", torch.float16)
|
||||
|
||||
x = torch.randn(batch, seq, hid, device="cuda", dtype=torch.float16)
|
||||
dynamic_shapes = model.get_dynamic_shapes()
|
||||
|
||||
if mode == "match":
|
||||
transform = match_explicit_rope
|
||||
|
||||
def checker(gm):
|
||||
return any(
|
||||
is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving)
|
||||
for n in gm.graph.nodes
|
||||
)
|
||||
|
||||
else: # mode == "match_layout"
|
||||
transform = match_rope_layout
|
||||
|
||||
def checker(gm):
|
||||
for n in gm.graph.nodes:
|
||||
if is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving):
|
||||
q_arg, k_arg, *rest = n.args
|
||||
if not (
|
||||
is_op(q_arg, torch.ops.aten.contiguous)
|
||||
and is_op(k_arg, torch.ops.aten.contiguous)
|
||||
):
|
||||
matched = False
|
||||
break
|
||||
|
||||
old_q, old_k = extract_output_tuple(n, 2)
|
||||
if old_q is None or old_k is None:
|
||||
matched = False
|
||||
break
|
||||
q_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_q.users)
|
||||
k_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_k.users)
|
||||
matched = q_transposed and k_transposed
|
||||
|
||||
return matched if layout != target_layout else not matched
|
||||
|
||||
if target_layout:
|
||||
_ = run_test(
|
||||
model,
|
||||
x,
|
||||
transform,
|
||||
checker,
|
||||
lambda num_p: num_p,
|
||||
1e-3, # atol
|
||||
1e-3, # rtol
|
||||
True, # test_load_hook
|
||||
True, # strict_loading
|
||||
dynamic_shapes, # dynamic_shapes
|
||||
target_layout,
|
||||
)
|
||||
else:
|
||||
_ = run_test(
|
||||
model,
|
||||
x,
|
||||
transform,
|
||||
checker,
|
||||
lambda num_p: num_p,
|
||||
1e-3, # atol
|
||||
1e-3, # rtol
|
||||
True, # test_load_hook
|
||||
True, # strict_loading
|
||||
dynamic_shapes, # dynamic_shapes
|
||||
)
|
||||