feat: [AutoDeploy] update rope matcher with minor variants (Deepseek) (#3638)

* add docstring to summarize current rope support

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor: replace call_method, adjust inserting order of cos_sin_cache calculation node

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* add unit test for triton rope and ds rope

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* update rope matcher to match DS RoPE, add custom op for reference, add unit test case

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* cache cos[pos_idx].unsqueeze and sin[pos_idxs].unsqueeze

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor doc update

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* separate pattern matching and optimization for explicit and complex rope + minor updates

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* clean rope impl in repo

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* replace fused_flattened_mla_with_cache's rope impl with torch_apply_rope_with_qk_interleaving, update unit test

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* separate layout infer and transpose to a new transformation

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* update rope_with_explicit_freqs and rope_with_input_interleaved to expose unsqueeze_dim and support match_rope_layout, add unit tests

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* solve merge conflict in transform.py, need to fix optimize_rope with cuda graph capture

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor clean up after rebase

Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com>

* fix pre-commit

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* support map to bnsd layout and infer unsqueeze_dim from op

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* fix cos/sin not the same across prompts in the same batch issue when mapping to flashinfer op

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* fix for unit test

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* fix custom op input/output node ordering issue for DeepSeek V3 rope

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* clean code

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* move flattening of cos_sin_cache to the graph, update flashinfer op docstring and test

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* debug transform unit test failure

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

---------

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Ubuntu <201670829+Fridah-nv@users.noreply.github.com>
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
This commit is contained in:
Fridah-nv 2025-05-16 06:55:32 -07:00 committed by GitHub
parent f5b6d453aa
commit bce281d592
15 changed files with 1572 additions and 758 deletions

View File

@ -10,4 +10,5 @@ from .mla import *
from .quant import *
from .rope import *
from .torch_attention import *
from .torch_rope import *
from .triton_attention import *

View File

@ -21,10 +21,11 @@ def apply_rope_with_input_pos_flashinfer(
Tensors of shape [batch, seq_len, n_head, head_dim] (or a 3D variant)
in half precision. Note: head_dim must be a multiple of 64.
- position_ids (torch.Tensor):
Precomputed tensor of positional indices; it is shared across calls in the graph.
Precomputed tensor of positional indices indicating idx in cos_sin_cache for each token;
Shape [batch, seq_len] or [batch * seq_len]
- cos_sin_cache (torch.Tensor):
Precomputed fused tensor created by concatenating the first half of the cosine and sine
components derived from the inv_freq.
components derived from the inv_freq. Shape [max_seq_len, head_dim]. Must be float32.
- is_neox (bool):
Flag to indicate whether to invoke the FlashInfer kernel in Neox mode.
@ -35,14 +36,13 @@ def apply_rope_with_input_pos_flashinfer(
"""
q_shape = q.shape
k_shape = k.shape
batch_size, seq_len = q_shape[:2]
head_dim = cos_sin_cache.shape[-1]
q_flat = q.view(batch_size * seq_len, -1)
k_flat = k.view(batch_size * seq_len, -1)
position_ids = position_ids.view(-1).to(q.device)
num_nnz = position_ids.shape[0]
position_ids = position_ids.to(q.device)
q_flat = q.view(num_nnz, -1)
k_flat = k.view(num_nnz, -1)
query_rotated_flash, key_rotated_flash = flashinfer.rope.apply_rope_with_cos_sin_cache(
position_ids, q_flat, k_flat, head_dim, cos_sin_cache, is_neox=is_neox
@ -60,4 +60,4 @@ def apply_rope_with_input_pos_flashinfer_fake(
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
return q, k
return torch.empty_like(q), torch.empty_like(k)
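A minimal usage sketch of the op above (a hedged example: it assumes a CUDA device with flashinfer installed, that the wrapper is registered as `torch.ops.rope.flashinfer` as the graph transforms in this PR call it, and illustrative sizes):

import torch
import tensorrt_llm._torch.auto_deploy  # noqa: F401  # registers the rope custom ops

batch, seq_len, n_head, head_dim = 2, 8, 4, 64  # head_dim must be a multiple of 64
q = torch.randn(batch, seq_len, n_head, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
# position ids index into cos_sin_cache, one entry per token: [batch * seq_len]
position_ids = torch.arange(seq_len, device="cuda").repeat(batch)
# fused cache: cos halves then sin halves per position -> [max_seq_len, head_dim], float32
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
angles = torch.outer(torch.arange(seq_len, device="cuda").float(), inv_freq)
cos_sin_cache = torch.cat([angles.cos(), angles.sin()], dim=-1)
q_rot, k_rot = torch.ops.rope.flashinfer(q, k, position_ids, cos_sin_cache, True)  # Neox mode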

View File

@ -17,7 +17,6 @@ from .attention_interface import (
PrepareMetadataCallable,
SequenceInfo,
)
from .torch_attention import apply_rotary_pos_emb_ds
from .triton_attention import _flattened_context_mha, _generate_mha
Constant = Union[int, float, str, None]
@ -74,26 +73,38 @@ def fused_flattened_mla_with_cache(
# Apply RoPE
if cos_sin_stacked.numel() > 0:
# Extract cos and sin from freqs_cis
cos = cos_sin_stacked[0, ...]
sin = cos_sin_stacked[1, ...]
cos_base = cos_sin_stacked[0, ...]
sin_base = cos_sin_stacked[1, ...]
# TODO: Use triton kernels for RoPE
# TODO: Add yarn support
for idx in range(seq_len.shape[0]):
(
q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
) = apply_rotary_pos_emb_ds(
q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
for i in range(seq_len.shape[0]):
start = seq_start[i]
length = seq_len[i]
# build position_ids
if s == 1:
idx = (input_pos[i] + length - 1).item()
pos_ids = torch.tensor(idx, device=cos_base.device)
else:
pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device)
cos = cos_base[pos_ids] # [..., 1, head_dim]
sin = sin_base[pos_ids]
q_slice = q_pe[start : start + length]
k_slice = k_pe[start : start + length]
q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
q_slice,
k_slice,
cos,
sin,
torch.arange(input_pos[idx] + seq_len[idx])[-1]
if s == 1
else torch.arange(input_pos[idx] + seq_len[idx]),
-2,
)
q_pe[start : start + length] = q_rot
k_pe[start : start + length] = k_rot
# Create query_states, key_states
query_states = torch.cat((q_nope, q_pe), dim=-1) # [b*s,n,d]
key_states = torch.cat((k_nope, k_pe.expand(*bs_view, num_heads, -1)), dim=-1) # [b*s,n,d]

View File

@ -138,92 +138,6 @@ def bsnd_grouped_sdpa_fake(
return torch.empty_like(query.contiguous())
# Function to apply rotary positional embeddings (RoPE)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
seq_len: int,
head_dim: int,
rope_theta: Optional[float] = None,
rope_scale: Optional[float] = None,
):
"""
Apply rotary positional embeddings to query and key tensors.
Args:
q: Query tensor of shape [batch, n_heads, seq_len, head_dim]
k: Key tensor of shape [batch, n_kv_heads, seq_len, head_dim]
seq_len: Sequence length
head_dim: Dimension of each head
rope_theta: Base value for RoPE (default 10000.0)
rope_scale: Scaling factor for positions (default 1.0)
Returns:
Tuple of transformed query and key tensors
"""
device = q.device
original_dtype = q.dtype
# Apply default values if None
theta = 10000.0 if rope_theta is None else rope_theta
scale = 1.0 if rope_scale is None else rope_scale
# Generate position indices
position = torch.arange(seq_len, device=device).float()
# Apply scaling factor to positions if provided
if scale != 1.0:
position = position / scale
# Create the frequency matrix - ensure stable computation in float32
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
# Compute the product of positions and frequencies
# Shape: [seq_len, head_dim/2]
freqs = torch.outer(position, inv_freq)
# Compute the rotation matrix elements: cos and sin
# Shape: [seq_len, head_dim/2]
emb = torch.cat((freqs, freqs), dim=-1)
# Ensure stable computation of sin/cos in float32
cos = torch.cos(emb).to(dtype=torch.float32)
sin = torch.sin(emb).to(dtype=torch.float32)
# Reshape for broadcasting
# Shape: [1, 1, seq_len, head_dim]
cos = cos.view(1, 1, seq_len, head_dim)
sin = sin.view(1, 1, seq_len, head_dim)
# Always compute in float32 for numerical stability
q_float = q.to(dtype=torch.float32)
k_float = k.to(dtype=torch.float32)
# For the even indices of the dimension
q_embed_even = q_float[..., 0::2]
q_embed_odd = q_float[..., 1::2]
k_embed_even = k_float[..., 0::2]
k_embed_odd = k_float[..., 1::2]
# Apply the rotation using the identities:
# q' = q * cos + rotate(q) * sin
# k' = k * cos + rotate(k) * sin
# where rotate(x) swaps the even and odd dimensions and negates the odd dimensions
q_rotated = torch.cat(
[
q_embed_even * cos[..., 0::2] - q_embed_odd * sin[..., 0::2],
q_embed_odd * cos[..., 1::2] + q_embed_even * sin[..., 1::2],
],
dim=-1,
)
k_rotated = torch.cat(
[
k_embed_even * cos[..., 0::2] - k_embed_odd * sin[..., 0::2],
k_embed_odd * cos[..., 1::2] + k_embed_even * sin[..., 1::2],
],
dim=-1,
)
# Convert back to the original dtype
return q_rotated.to(dtype=original_dtype), k_rotated.to(dtype=original_dtype)
def update_kv_cache(
key_states: torch.Tensor,
value_states: torch.Tensor,
@ -248,46 +162,6 @@ def update_kv_cache(
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
@torch.inference_mode()
def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`):
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
used to pass offsetted position ids when working with a KV-cache.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q = q.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(q)
k = k.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(k)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
@torch.library.custom_op("attention::fused_mla_ref", mutates_args=())
def fused_mla_ref(
q_nope: torch.Tensor,
@ -325,23 +199,33 @@ def fused_mla_ref(
value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous()
if freqs_cis is not None:
cos = freqs_cis[0, ...]
sin = freqs_cis[1, ...]
for idx in range(seq_len.shape[0]):
(
q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
) = apply_rotary_pos_emb_ds(
q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...],
cos_base = freqs_cis[0, ...]
sin_base = freqs_cis[1, ...]
for i in range(seq_len.shape[0]):
start = seq_start[i]
length = seq_len[i]
if q_len == 1:
idx = (input_pos[i] + length - 1).item()
pos_ids = torch.tensor(idx, device=cos_base.device)
else:
pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device)
cos = cos_base[pos_ids] # [..., 1, head_dim]
sin = sin_base[pos_ids]
q_slice = q_pe[start : start + length]
k_slice = k_pe[start : start + length]
q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
q_slice,
k_slice,
cos,
sin,
torch.arange(input_pos[idx] + seq_len[idx])[-1]
if q_len == 1
else torch.arange(input_pos[idx] + seq_len[idx]),
-2,
)
q_pe[start : start + length] = q_rot
k_pe[start : start + length] = k_rot
query_states = k_pe.new_empty(*bs_view, num_heads, q_head_dim) # [b*s,n,d]
query_states[..., :qk_nope_head_dim] = q_nope
query_states[..., qk_nope_head_dim:] = q_pe
@ -454,7 +338,9 @@ def fused_mla(
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
kv_seq_len = value_states.shape[-2]
q_pe, k_pe = apply_rotary_pos_emb_ds(q_pe, k_pe, cos, sin, position_ids)
cos = cos[position_ids]
sin = sin[position_ids]
q_pe, k_pe = torch.ops.rope.torch_apply_rope_with_qk_interleaving(q_pe, k_pe, cos, sin)
query_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim)
query_states[:, :, :, :qk_nope_head_dim] = q_nope

View File

@ -0,0 +1,100 @@
from typing import Tuple
import torch
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
@torch.library.custom_op("rope::torch_apply_rope_with_explicit_cos_sin", mutates_args=())
def torch_apply_rope_with_explicit_cos_sin(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Reference PyTorch implementation of HF-style RoPE:
- Input layout: non-interleaved, [B, N, S, D] with unsqueeze_dim=1 and
[B, S, N, D] with unsqueeze_dim=2, default is [B, N, S, D]
- Frequencies are provided as separate `cos` and `sin` tensors of shape [B, S, head_dim].
"""
# In HF, the cos/sin tensors are already passed in as x.dtype; cast here to make sure.
cos = cos.type_as(q)
sin = sin.type_as(q)
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
@torch_apply_rope_with_explicit_cos_sin.register_fake
def torch_apply_rope_with_explicit_cos_sin_fake(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(q), torch.empty_like(k)
@torch.library.custom_op("rope::torch_apply_rope_with_complex_freqs", mutates_args=())
def torch_apply_rope_with_complex_freqs(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor, # shape [B, S, head_dim//2]
unsqueeze_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Reference PyTorch implementation of interleaved (complex) RoPE:
- Input layout: [B, S, N, D] (interleaved)
- Frequencies are combined into a single complex-valued tensor `freqs_cis`
of shape [B, S, head_dim // 2].
"""
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs = freqs_cis.unsqueeze(unsqueeze_dim)
xq_out = torch.view_as_real(xq_ * freqs).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
@torch_apply_rope_with_complex_freqs.register_fake
def torch_apply_rope_with_complex_freqs_fake(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor, # shape [B, S, head_dim//2]
unsqueeze_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(xq), torch.empty_like(xk)
@torch.library.custom_op("rope::torch_apply_rope_with_qk_interleaving", mutates_args=())
def torch_apply_rope_with_qk_interleaving(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
DeepSeek-style RoPE: interleaves Q/K channels and returns rotated (q_embed, k_embed).
- Input layout: [B, S, N, D] or [B*S, N, D] or [B, N, S, D]
- Frequencies are provided as separate `cos` and `sin` tensors of shape
[B, S, 1, D] or [B*S, 1, D] or [B, 1, S, D] matching input shape.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
# Rewritten from the 4D-only reference below so that 3D inputs are also accepted:
# b, h, s, d = q.shape
# q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
# b, h, s, d = k.shape
# k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
q = q.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(q)
k = k.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(k)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
@torch_apply_rope_with_qk_interleaving.register_fake
def torch_apply_rope_with_qk_interleaving_fake(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(q), torch.empty_like(k)
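A hedged usage sketch of the explicit and complex reference ops above (it assumes the ops are registered by importing `tensorrt_llm._torch.auto_deploy`; the shapes and toy frequency setup are illustrative, not prescribed by this file):

import torch
import tensorrt_llm._torch.auto_deploy  # noqa: F401  # registers the rope::torch_apply_rope_* ops

B, S, N, D = 2, 4, 3, 64
q = torch.randn(B, S, N, D)
k = torch.randn(B, S, N, D)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))
angles = torch.outer(torch.arange(S).float(), inv_freq)  # [S, D // 2]

# HF-style: separate cos/sin of shape [B, S, D]; unsqueeze_dim=2 for the [B, S, N, D] layout
cos = torch.cat([angles.cos(), angles.cos()], dim=-1).expand(B, S, D)
sin = torch.cat([angles.sin(), angles.sin()], dim=-1).expand(B, S, D)
q_hf, k_hf = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin(q, k, cos, sin, 2)

# Complex variant: a single complex-valued freqs_cis of shape [B, S, D // 2]
freqs_cis = torch.polar(torch.ones(S, D // 2), angles).expand(B, S, D // 2)
q_cx, k_cx = torch.ops.rope.torch_apply_rope_with_complex_freqs(q, k, freqs_cis, 2)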

View File

@ -1,12 +1,12 @@
from collections import defaultdict
from typing import Callable, Optional
from typing import Optional
import torch
from torch.fx import GraphModule, Node
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.node_utils import identify_regions_between_residuals, is_linear_op, is_op
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op
from .._graph import canonicalize_graph
@ -36,7 +36,7 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule:
common_ancessor2 = _find_lowest_common_ancessor(arg2_list)
if not common_ancessor2:
continue
selected_experts = _bfs(
selected_experts = bfs(
common_ancessor2,
lambda node: is_op(node, torch.ops.aten.one_hot),
attr_next="all_input_nodes",
@ -153,26 +153,6 @@ def _insert_fused_moe_ops(gm: GraphModule):
graph.erase_node(node)
def _bfs(
node: Node, target: Callable, attr_next: str = "users", boundary: Optional[Node] = None
) -> Node:
queue = [node]
visited = set()
while queue:
cur_node = queue.pop(0)
if boundary is not None and cur_node == boundary:
continue # Skip the boundary node.
if target(cur_node):
return cur_node
for next_node in getattr(cur_node, attr_next):
if boundary is not None and next_node == boundary:
continue # Do not expand past the boundary.
if next_node not in visited:
visited.add(next_node)
queue.append(next_node)
raise RuntimeError(f"Could not find node with target condition {target}.")
def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
"""
Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following
@ -326,7 +306,7 @@ def _find_final_hidden_state_node(
For each expert output node (from the expert compute pattern), this function:
1. Retrieves a multiplication node from its users.
2. Extracts the second argument from the multiplication node (assumed to be the index node).
3. Uses a BFS (via _bfs) to locate the subsequent index_add_ node (guarded by the end_boundary).
3. Uses a BFS to locate the subsequent index_add_ node (guarded by the end_boundary).
After collecting all such index_add_ nodes, the final hidden state node is determined
as the one that is not used by any of the other index_add_ nodes.
@ -346,7 +326,7 @@ def _find_final_hidden_state_node(
if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
return None
index_node = mul_node.args[1]
index_add_node = _bfs(
index_add_node = bfs(
index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary
)
if not index_add_node:
@ -412,7 +392,7 @@ def _remove_dead_inplace_nodes_in_region(
return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0
try:
node_to_remove = _bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
ad_logger.debug(f"Removing In-place Dead Node: {node_to_remove}")
graph.erase_node(node_to_remove)
return True

View File

@ -1,86 +1,141 @@
"""
This transformation defines two main RoPE (Rotary Positional Embedding) pattern matchers used
to identify RoPE subgraphs and replace them with reference custom ops (`torch.ops.rope.torch_apply_rope_*`),
which `optimize_rope` later swaps for the optimized `torch.ops.rope.flashinfer` kernel.
Supported RoPE variants:
1. Explicit Cos/Sin Multiplication (HF-style, e.g., LLaMA, Mixtral, Qwen)
- Input layout: non-interleaved, [B, N, S, D] with unsqueeze_dim=1 and
[B, S, N, D] with unsqueeze_dim=2, default is [B, N, S, D]
- Frequencies are provided as separate `cos` and `sin` tensors of shape [B, S, head_dim].
- Source code:
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
2. Complex Multiplication (GPTJ/Llama-stack-style, interleaved)
- Input layout: [B, S, N, D] (interleaved)
- Frequencies are combined into a single complex-valued tensor `freqs_cis` of shape [B, S, head_dim // 2].
- Source code:
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2)
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
Supported minor variants:
- DeepSeekV3: reshape + transpose (Q/K interleaving) before applying RoPE, with
  dynamic position-based updates to the frequency cache.
TODO: Support other variants:
- Phi-4: rotary applied only to part of the hidden dimension (q_rot, q_pass split).
- LLaMA4 Vision: 2D rotary frequencies constructed from image patches.
"""
import operator
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional
from typing import Any, DefaultDict, Dict, List, Optional, Sequence
import torch
from torch.fx import GraphModule, Node
from ...utils.logger import ad_logger
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_op
from ...utils.node_utils import bfs, extract_output_tuple, identify_regions_between_residuals, is_op
from .._graph import canonicalize_graph
def match_rope_v1(gm: GraphModule) -> GraphModule:
def match_explicit_rope(gm: GraphModule) -> GraphModule:
"""
Identify and replace legacy RoPE subgraphs (explicit cos/sin multiplication pattern):
Identify and replace RoPE subgraphs (explicit cos/sin multiplication pattern):
output = (raw * unsqueeze(cos)) + (rotate_half(raw) * unsqueeze(sin))
- HF-style: output = (raw * unsqueeze(cos)) + (rotate_half(raw) * unsqueeze(sin))
- DS-style: requires interleaving Q/K before the cos/sin mul
If exactly two such branches (query and key) are detected within each region, they're replaced
by a call to `torch.ops.rope.flashinfer`.
by a call to `rope::torch_apply_rope_with_explicit_cos_sin` or
`rope::torch_apply_rope_with_qk_interleaving`, respectively.
"""
ad_logger.info("Match explicit(HF) style RoPE")
graph = gm.graph
boundary_nodes: List[torch.fx.Node] = identify_regions_between_residuals(gm)
rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None)
rope_position_ids_cache: Dict[str, Node] = {}
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
matches = []
matches = [] # list of (match_info, is_ds)
node = start_boundary
while node != end_boundary:
if is_op(node, torch.ops.aten.add):
match_info = _match_rotary_subpattern_V1(node)
if match_info:
matches.append(match_info)
explicit = _match_explicit_rope_subpattern(node)
if explicit is not None:
raw = explicit["raw_input"]
# check if this raw is result of DS interleave
inter = _match_input_interleave_pattern(raw)
if inter is not None:
ds_match = {
"raw_input": inter["interleaved"],
"unsqueeze_cos": explicit["unsqueeze_cos"],
"unsqueeze_sin": explicit["unsqueeze_sin"],
"add_node": explicit["add_node"],
}
matches.append((ds_match, True))
else:
matches.append((explicit, False))
node = node.next
if not matches:
continue
if len(matches) != 2:
raise RuntimeError(
ad_logger.warning(
f"Expected exactly 2 legacy RoPE branches between {start_boundary} and {end_boundary}, "
f"found {len(matches)}."
)
continue
# Assume the first matched branch is query (q), second is key (k).
# This assumption is based on the default ordering in the exported graph,
# since node naming conventions don't reliably indicate q/k branches.
q_match, k_match = matches
_process_rope_v1(
graph,
q_match,
k_match,
start_boundary,
rope_flash_cache,
rope_position_ids_cache,
)
(q_match, q_is_ds), (k_match, k_is_ds) = matches
if q_is_ds != k_is_ds:
ad_logger.warning("Mismatched RoPE types between q and k branches")
continue
if q_is_ds:
_process_input_interleave_rope(graph, q_match, k_match)
else:
_process_explicit_rope(graph, q_match, k_match, start_boundary)
gm = canonicalize_graph(gm)
return gm
def match_rope_v2(gm: GraphModule) -> GraphModule:
def match_complex_rope(gm: GraphModule) -> GraphModule:
"""
Identify and replace RoPE subgraphs using complex multiplication pattern:
output = type_as(flatten(view_as_real(mul(view_as_complex(reshape(to_dtype(x))), unsqueeze(freqs_cis, 2)))), x)
If exactly two such branches (query and key) are detected within each region, they're replaced
by a call to `torch.ops.rope.flashinfer`.
by a call to `torch.ops.rope.torch_apply_rope_with_complex_freqs`.
"""
ad_logger.info("Match Complex style RoPE")
graph = gm.graph
boundary_nodes: List[torch.fx.Node] = identify_regions_between_residuals(gm)
rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None)
rope_position_ids_cache: Dict[str, Node] = {}
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
matches = []
node = start_boundary
while node != end_boundary:
if is_op(node, torch.ops.aten.type_as):
match_info = _match_rotary_subpattern_V2(node)
match_info = _match_complex_rope_subpattern(node)
if match_info:
matches.append(match_info)
node = node.next
@ -88,28 +143,344 @@ def match_rope_v2(gm: GraphModule) -> GraphModule:
if not matches:
continue
if len(matches) != 2:
raise RuntimeError(
ad_logger.warning(
f"Expected exactly 2 complex RoPE branches between {start_boundary} and {end_boundary}, "
f"found {len(matches)}."
)
continue
# Assume the first matched branch is query (q), second is key (k).
# This assumption is based on the default ordering in the exported graph,
# since node naming conventions don't reliably indicate q/k branches.
q_match, k_match = matches
_process_rope_v2(
graph,
q_match,
k_match,
rope_flash_cache,
rope_position_ids_cache,
)
_process_complex_rope(graph, q_match, k_match)
gm = canonicalize_graph(gm)
return gm
def _match_rotary_subpattern_V1(add_node: Node) -> Optional[Dict[str, Node]]:
def _get_default_unsqueeze_dim(op):
schema = next(iter(op._schemas.values()))
for a in schema.arguments:
if a.name == "unsqueeze_dim" and a.has_default_value:
return a.default_value
raise RuntimeError(f"No default unsqueeze_dim on {op}")
def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphModule:
"""
Match and transpose the inputs and outputs of RoPE ops to the specified layout so that they meet the requirements of the optimized ops.
Supported layouts are 'bsnd' (batch, seq, head, dim) and 'bnsd' (batch, head, seq, dim).
"""
supported = {"bsnd", "bnsd"}
if expected_layout.lower() not in supported:
ad_logger.warning(
f"Unsupported RoPE layout '{expected_layout}'; expected '{supported}'. Skipping RoPE layout matching."
)
return gm
ad_logger.info(f"Match RoPE layout to {expected_layout}")
graph = gm.graph
rope_ops = {
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin,
torch.ops.rope.torch_apply_rope_with_qk_interleaving,
torch.ops.rope.torch_apply_rope_with_complex_freqs,
}
need_transpose = False
need_canonicalize_graph = False
for node in graph.nodes:
if not is_op(node, rope_ops):
continue
rope_op = next(op for op in rope_ops if is_op(node, op))
if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
q_node, k_node, freqs_node, *rest = node.args
unsq = rest[0] if rest else _get_default_unsqueeze_dim(rope_op)
else:
q_node, k_node, cos_node, sin_node, *rest = node.args
unsq = rest[0] if rest else _get_default_unsqueeze_dim(rope_op)
if unsq == 2:
current_layout = "bsnd"
elif unsq == 1:
current_layout = "bnsd"
else:
ad_logger.warning(
"Unsqueeze_dim is not one of [1, 2]. "
"Unable to infer layout of q node. Skip layout matching"
)
continue
need_transpose = expected_layout.lower() != current_layout
if not need_transpose:
continue
need_canonicalize_graph = True
# retrieve q and k output node from node
q_rope_old, k_rope_old = extract_output_tuple(node, 2)
if q_rope_old is None or k_rope_old is None:
ad_logger.warning(
f"Failed to extract all two outputs from the explicit op, \
get {q_rope_old}, {k_rope_old}, fail to match rope layout with {node} with"
)
continue
ad_logger.debug(
f"Inferred RoPE input layout: '{current_layout}']Mapping layout to '{expected_layout}']"
)
with graph.inserting_before(node):
q_for_op = graph.call_function(torch.ops.aten.transpose, args=(q_node, 1, 2))
k_for_op = graph.call_function(torch.ops.aten.transpose, args=(k_node, 1, 2))
q_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(q_for_op,))
k_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(k_for_op,))
q_for_op_contig.meta["val"] = q_node.meta["val"].transpose(1, 2)
k_for_op_contig.meta["val"] = k_node.meta["val"].transpose(1, 2)
if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
new_args = (
q_for_op_contig,
k_for_op_contig,
freqs_node,
2 if expected_layout.lower() == "bsnd" else 1,
) # unsqueeze_dim updated
else:
new_args = (
q_for_op_contig,
k_for_op_contig,
cos_node,
sin_node,
2 if expected_layout.lower() == "bsnd" else 1,
) # unsqueeze_dim updated
node.args = new_args
with graph.inserting_after(q_rope_old):
q_rope_new = graph.call_function(torch.ops.aten.transpose, args=(q_rope_old, 1, 2))
with graph.inserting_after(k_rope_old):
k_rope_new = graph.call_function(torch.ops.aten.transpose, args=(k_rope_old, 1, 2))
# Preserve fake tensor in meta["val"] for the transposed inputs
q_rope_new.meta["val"] = q_rope_old.meta["val"]
q_rope_old.meta["val"] = q_rope_old.meta["val"].transpose(1, 2)
k_rope_new.meta["val"] = k_rope_old.meta["val"]
k_rope_old.meta["val"] = k_rope_old.meta["val"].transpose(1, 2)
q_rope_old.replace_all_uses_with(q_rope_new)
k_rope_old.replace_all_uses_with(k_rope_new)
q_rope_new.args = (q_rope_old, 1, 2)
k_rope_new.args = (k_rope_old, 1, 2)
if need_canonicalize_graph:
gm = canonicalize_graph(gm)
return gm
def optimize_rope(gm: GraphModule) -> GraphModule:
"""
Scan the FX graph and replace calls to the torch-reference RoPE ops with
the optimized `rope::flashinfer` kernel.
Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
and reuses those nodes when possible.
"""
ad_logger.info("RoPE optimization")
graph = gm.graph
rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None)
rope_position_ids_cache: Dict[str, Node] = {}
for node in list(graph.nodes):
if is_op(node, torch.ops.rope.torch_apply_rope_with_explicit_cos_sin):
_optimize_explicit(graph, node, rope_flash_cache, rope_position_ids_cache)
elif is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs):
_optimize_complex(graph, node, rope_flash_cache, rope_position_ids_cache)
gm = canonicalize_graph(gm)
return gm
def _optimize_explicit(
graph: GraphModule, node: Node, cache: Dict[Any, Node], pos_cache: Dict[str, Node]
) -> None:
# node.args may be (q, k, cos, sin) or (q, k, cos, sin, unsq)
q_node, k_node, cos_node, sin_node, *rest = node.args
# retrieve q and k output node from node
q_rope_old, k_rope_old = extract_output_tuple(node, 2)
if q_rope_old is None or k_rope_old is None:
ad_logger.warning(
f"Failed to extract all two outputs from the explicit op, \
get {q_rope_old}, {k_rope_old}, fail to replace {node} with flashinfer rope"
)
return
# Sanity check on head_dim
if not _validate_rope_inputs(q_node, k_node):
return
# Sanity check that input layout is BSND (no transpose needed).
q_fake = q_node.meta.get("val", None)
if q_fake is not None and len(q_fake.shape) > 2:
if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)):
ad_logger.warning(
f"""Sanity check failed: q_fake should have shape [b, s, n, d],
s should be symbolic and n should be int, instead got shape {q_fake.shape}"""
)
return
elif q_fake is not None:
ad_logger.warning(
f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}"
)
return
head_dim = cos_node.meta["val"].shape[-1]
half_head_dim = head_dim // 2
cache_key = (cos_node, sin_node)
if cache_key in cache:
fused_cos_sin_to = cache[cache_key]
else:
with graph.inserting_after(cos_node):
cos_prefix = graph.call_function(
torch.ops.aten.slice, args=(cos_node, -1, 0, half_head_dim)
)
with graph.inserting_after(sin_node):
sin_prefix = graph.call_function(
torch.ops.aten.slice, args=(sin_node, -1, 0, half_head_dim)
)
with graph.inserting_after(sin_prefix):
fused_cos_sin = graph.call_function(
torch.ops.aten.cat, args=((cos_prefix, sin_prefix), -1)
)
with graph.inserting_after(q_node):
sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 0))
sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 1))
with graph.inserting_after(_get_last_node([sym_batch, sym_seq])):
bs_seq = graph.call_function(operator.mul, args=(sym_batch, sym_seq))
with graph.inserting_after(_get_last_node([bs_seq, fused_cos_sin])):
fused_cos_sin_flat = graph.call_function(
torch.ops.aten.view, args=(fused_cos_sin, (bs_seq, -1))
)
with graph.inserting_after(fused_cos_sin_flat):
fused_cos_sin_to = graph.call_function(
torch.ops.aten.to, args=(fused_cos_sin_flat, torch.float32)
)
cache[cache_key] = fused_cos_sin_to
with graph.inserting_before(node):
position_ids = _get_position_ids(
graph,
q_node,
batch_dim=0,
seq_dim=1,
rope_position_ids_cache=pos_cache,
)
flash_node = graph.call_function(
torch.ops.rope.flashinfer,
args=(q_node, k_node, position_ids, fused_cos_sin_to, True),
)
with graph.inserting_after(flash_node):
q_rope_new = graph.call_function(operator.getitem, args=(flash_node, 0))
k_rope_new = graph.call_function(operator.getitem, args=(flash_node, 1))
q_rope_new.meta["val"] = q_rope_old.meta.get("val", None)
k_rope_new.meta["val"] = k_rope_old.meta.get("val", None)
q_rope_old.replace_all_uses_with(q_rope_new)
k_rope_old.replace_all_uses_with(k_rope_new)
graph.erase_node(q_rope_old)
graph.erase_node(k_rope_old)
def _optimize_complex(
graph: GraphModule, node: Node, cache: Dict[Any, Node], pos_cache: Dict[str, Node]
) -> None:
q_node, k_node, inv_freq_node = node.args
# Sanity check on head_dim
if not _validate_rope_inputs(q_node, k_node):
return
# Sanity check that input layout is BSND (no transpose needed).
q_fake = q_node.meta.get("val", None)
if q_fake is not None and len(q_fake.shape) > 2:
if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)):
ad_logger.warning(
f"""Sanity check failed: q_fake should have shape [b, s, n, d],
s should be symbolic and n should be int, instead got shape {q_fake.shape}"""
)
return
elif q_fake is not None:
ad_logger.warning(
f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}"
)
return
# Retrieve or register the lookup table for inv_freq_node -> cos_sin_flash
if inv_freq_node in cache:
cos_sin_flash = cache[inv_freq_node]
else:
# Compute the fused cosine/sine cache.
with graph.inserting_after(inv_freq_node):
real_part = graph.call_function(torch.ops.aten.real, args=(inv_freq_node,))
imag_part = graph.call_function(torch.ops.aten.imag, args=(inv_freq_node,))
with graph.inserting_after(real_part):
cos_sin_flash_3d = graph.call_function(
torch.ops.aten.cat, args=((real_part, imag_part), -1)
)
with graph.inserting_after(q_node):
sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 0))
sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 1))
with graph.inserting_after(_get_last_node([sym_batch, sym_seq])):
bs_seq = graph.call_function(operator.mul, args=(sym_batch, sym_seq))
with graph.inserting_after(_get_last_node([bs_seq, cos_sin_flash_3d])):
fused_cos_sin_flat = graph.call_function(
torch.ops.aten.view, args=(cos_sin_flash_3d, (bs_seq, -1))
)
with graph.inserting_after(fused_cos_sin_flat):
cos_sin_flash = graph.call_function(
torch.ops.aten.to, args=(fused_cos_sin_flat, torch.float32)
)
cache[inv_freq_node] = cos_sin_flash
with graph.inserting_before(node):
position_ids = _get_position_ids(
graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=pos_cache
)
flash_node = graph.call_function(
torch.ops.rope.flashinfer,
args=(q_node, k_node, position_ids, cos_sin_flash, False),
)
flash_node.meta["val"] = node.meta.get("val", None)
node.replace_all_uses_with(flash_node)
graph.erase_node(node)
def _match_input_interleave_pattern(node: Node) -> Optional[Dict[str, Node]]:
"""
Detect DeepSeek-style interleave on Q/K:
reshape(transpose(view(raw, [b,h,s,d//2,2]), 4, 3), [b,h,s,d])
Returns:
{"interleaved": raw_node} if matched, else None.
"""
if not is_op(node, torch.ops.aten.reshape):
return None
transpose_node = node.args[0]
if not is_op(transpose_node, torch.ops.aten.transpose):
return None
view_node = transpose_node.args[0]
if not is_op(view_node, torch.ops.aten.view):
return None
raw_node = view_node.args[0]
if not isinstance(raw_node, Node):
return None
return {"interleaved": raw_node}
def _match_explicit_rope_subpattern(add_node: Node) -> Optional[Dict[str, Node]]:
"""
Given an aten.add.Tensor node that is expected to compute:
output = (raw_input * unsqueeze(cos)) + (rotate_half(raw_input) * unsqueeze(sin))
@ -195,7 +566,7 @@ def _match_rotary_subpattern_V1(add_node: Node) -> Optional[Dict[str, Node]]:
}
def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]]:
def _match_complex_rope_subpattern(type_as_node: Node) -> Optional[Dict[str, Node]]:
"""
Given a type_as node, this function inspects the graph
structure and returns a dictionary with:
@ -245,9 +616,10 @@ def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]]
else:
return None
# Record the dimension along which the unsqueeze is performed.
if not (len(unsqueeze_node.args) >= 2 and unsqueeze_node.args[1] == 2):
if not (len(unsqueeze_node.args) >= 2):
return None
unsqueeze_dim = unsqueeze_node.args[1]
inv_freq_candidate = unsqueeze_node.args[0]
# Match the view_as_complex branch.
@ -273,27 +645,27 @@ def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]]
"input": input_tensor,
"inv_freq": inv_freq_candidate,
"out": type_as_node,
"unsqueeze_dim": unsqueeze_dim,
}
def _process_rope_v1(
graph: GraphModule,
def _process_explicit_rope(
graph: GraphModule.graph,
q_match: Dict[str, Node],
k_match: Dict[str, Node],
start_boundary: Node,
rope_flash_cache: DefaultDict[Any, Optional[Node]],
rope_position_ids_cache: Dict[str, Node],
) -> None:
"""
Process a region that matched the legacy RoPE pattern (v1).
Inserts the custom op (flashinfer) and replaces the original add nodes.
Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
and reuses those nodes when possible.
Replace matched Explicit RoPE subgraph with `rope::torch_apply_rope_with_explicit_cos_sin`.
"""
q_node = q_match["raw_input"]
k_node = k_match["raw_input"]
cos_node = q_match["unsqueeze_cos"].args[0]
sin_node = q_match["unsqueeze_sin"].args[0]
cos_unsq = q_match["unsqueeze_cos"]
sin_unsq = q_match["unsqueeze_sin"]
cos_node = cos_unsq.args[0]
sin_node = sin_unsq.args[0]
unsq_dim = cos_unsq.args[1]
add_node = q_match["add_node"]
# Sanity-check: ensure cos/sin nodes trace back to aten.cos/aten.sin.
bfs(
@ -309,167 +681,197 @@ def _process_rope_v1(
boundary=start_boundary,
)
# Infer input layout; default to [b, n, s, d] if inference fails.
q_fake = q_node.meta.get("val", None)
if q_fake is not None and len(q_fake.shape) > 2:
need_transpose = isinstance(q_fake.shape[1], int)
ad_logger.debug(
f"Inferred RoPE input layout: [{'[b, n, s, d]' if need_transpose else '[b, s, n, d]'}]"
)
# Additional sanity check for the third dimension
if need_transpose:
if not isinstance(q_fake.shape[2], torch.SymInt):
ad_logger.warning(
"Sanity check failed: q_fake.shape[2] should be symbolic. Defaulting to [b, n, s, d]"
)
need_transpose = True
else:
if not isinstance(q_fake.shape[1], torch.SymInt):
ad_logger.warning(
"Sanity check failed: q_fake.shape[2] should be symbolic. Defaulting to [b, n, s, d]"
)
need_transpose = True
else:
ad_logger.warning("Unable to infer layout of q node. Defaulting to [b, n, s, d].")
need_transpose = True
with graph.inserting_before(q_match["add_node"]):
if need_transpose:
q_for_op = graph.call_function(torch.ops.aten.transpose.int, args=(q_node, 1, 2))
k_for_op = graph.call_function(torch.ops.aten.transpose.int, args=(k_node, 1, 2))
q_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(q_for_op,))
k_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(k_for_op,))
else:
q_for_op_contig, k_for_op_contig = q_node, k_node
head_dim = cos_node.meta["val"].shape[-1]
half_head_dim = head_dim // 2
cache = rope_flash_cache
cache_key = (cos_node, sin_node)
if cache_key in cache:
fused_cos_sin = cache[cache_key]
else:
cos_prefix = graph.call_function(
torch.ops.aten.slice, args=(cos_node, -1, 0, half_head_dim)
)
sin_prefix = graph.call_function(
torch.ops.aten.slice, args=(sin_node, -1, 0, half_head_dim)
)
fused_cos_sin = graph.call_function(
torch.ops.aten.cat, args=((cos_prefix, sin_prefix), -1)
)
fused_cos_sin = graph.call_function(operator.getitem, args=(fused_cos_sin, 0))
fused_cos_sin = graph.call_function(
torch.ops.aten.to.dtype, args=(fused_cos_sin, torch.float32)
)
cache[cache_key] = fused_cos_sin
position_ids = _get_position_ids(
graph,
q_for_op_contig,
batch_dim=0,
seq_dim=1,
rope_position_ids_cache=rope_position_ids_cache,
with graph.inserting_before(add_node):
rope_node = graph.call_function(
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin,
args=(q_node, k_node, cos_node, sin_node, unsq_dim),
)
flash_node = graph.call_function(
torch.ops.rope.flashinfer,
args=(q_for_op_contig, k_for_op_contig, position_ids, fused_cos_sin, True),
with graph.inserting_after(rope_node):
out_q = graph.call_function(operator.getitem, args=(rope_node, 0))
out_k = graph.call_function(operator.getitem, args=(rope_node, 1))
out_q.meta["val"] = add_node.meta.get("val", None)
out_k.meta["val"] = k_match["add_node"].meta.get("val", None)
q_match["add_node"].replace_all_uses_with(out_q)
k_match["add_node"].replace_all_uses_with(out_k)
def _process_complex_rope(
graph: GraphModule.graph,
q_match: Dict[str, Node],
k_match: Dict[str, Node],
) -> None:
"""
Replace matched Complex RoPE subgraph with `rope::torch_apply_rope_with_complex_freqs`.
"""
xq = q_match["input"]
xk = k_match["input"]
inv = q_match["inv_freq"]
usdim = q_match["unsqueeze_dim"]
out_node = q_match.get("out")
if inv != k_match["inv_freq"]:
ad_logger.warning(
"Mismatch of freqs_cis (inv_freq) between branches. Fail to match complex rope pattern"
)
return
with graph.inserting_before(out_node):
rope_node = graph.call_function(
torch.ops.rope.torch_apply_rope_with_complex_freqs,
args=(xq, xk, inv, usdim),
)
with graph.inserting_after(flash_node):
raw_q = graph.call_function(operator.getitem, args=(flash_node, 0))
raw_k = graph.call_function(operator.getitem, args=(flash_node, 1))
with graph.inserting_after(rope_node):
out_q = graph.call_function(operator.getitem, args=(rope_node, 0))
out_k = graph.call_function(operator.getitem, args=(rope_node, 1))
if need_transpose:
with graph.inserting_after(raw_q):
new_q = graph.call_function(torch.ops.aten.transpose.int, args=(raw_q, 1, 2))
with graph.inserting_after(raw_k):
new_k = graph.call_function(torch.ops.aten.transpose.int, args=(raw_k, 1, 2))
else:
new_q, new_k = raw_q, raw_k
out_q.meta["val"] = out_node.meta.get("val", None)
out_k.meta["val"] = k_match["out"].meta.get("val", None)
new_q.meta["val"] = q_match["add_node"].meta.get("val", None)
new_k.meta["val"] = k_match["add_node"].meta.get("val", None)
q_match["add_node"].replace_all_uses_with(new_q)
k_match["add_node"].replace_all_uses_with(new_k)
out_node.replace_all_uses_with(out_q)
k_match["out"].replace_all_uses_with(out_k)
def _process_rope_v2(
def _process_input_interleave_rope(
graph: GraphModule,
q_match: Dict[str, Node],
k_match: Dict[str, Node],
rope_flash_cache: DefaultDict[Any, Optional[Node]],
rope_position_ids_cache: Dict[str, Node],
) -> None:
"""
Process a region that matched the complex-multiplication RoPE pattern (v2).
Inserts the custom op (flashinfer) after extracting frequency information,
and replaces the original type_as nodes.
Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
and reuses those nodes when possible.
Replace a matched DS-style RoPE subgraph with a call to rope::torch_apply_rope_with_qk_interleaving.
Cache the one-time unsqueeze of cos/sin.
"""
q_node = q_match["input"]
k_node = k_match["input"]
inv_freq_node = q_match["inv_freq"]
q_node = q_match["raw_input"]
k_node = k_match["raw_input"]
cos_node = q_match["unsqueeze_cos"].args[0]
sin_node = q_match["unsqueeze_sin"].args[0]
# Patch for the case where the q output node appears before the k input node in the graph:
# move the q output down to just before its first user so that the graph remains in
# topological order after inserting the apply-RoPE custom op.
q_match["add_node"] = _move_node_before_first_user(q_match["add_node"])
if inv_freq_node != k_match["inv_freq"]:
raise RuntimeError("Mismatch of freqs_cis (inv_freq) between branches.")
# Infer unsqueeze_dim from layout
unsq_dim = 1
fake = q_node.meta.get("val", None)
if fake is not None and len(fake.shape) == 4:
# if shape[1] is symbolic, the layout is [B, S, N, D] (BSND), so the head axis is dim 2
if isinstance(fake.shape[1], torch.SymInt):
unsq_dim = 2
else:
unsq_dim = 1
# Sanity check that input layout is BSND (no transpose needed).
q_fake = q_node.meta.get("val", None)
if q_fake is not None and len(q_fake.shape) > 2:
if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)):
with graph.inserting_after(_get_last_node([q_node, k_node, cos_node, sin_node])):
ds_node = graph.call_function(
torch.ops.rope.torch_apply_rope_with_qk_interleaving,
args=(q_node, k_node, cos_node, sin_node, unsq_dim),
)
with graph.inserting_after(ds_node):
q_out = graph.call_function(operator.getitem, args=(ds_node, 0))
k_out = graph.call_function(operator.getitem, args=(ds_node, 1))
q_out.meta["val"] = q_match["add_node"].meta.get("val", None)
k_out.meta["val"] = k_match["add_node"].meta.get("val", None)
q_match["add_node"].replace_all_uses_with(q_out)
k_match["add_node"].replace_all_uses_with(k_out)
def _move_node_before_first_user(node: Node) -> Node:
"""
Remove `node` from the graph and re-insert a clone of it immediately
before its earliest user. Returns the new node.
If `node` has no users, or is already right before its first user,
this is a no-op and returns the original node.
"""
graph = node.graph
ordering = list(graph.nodes)
users = list(node.users)
if not users:
return node
# locate the earliest user in the current ordering
first_user = min(users, key=lambda u: ordering.index(u))
if ordering.index(node) == ordering.index(first_user) - 1:
return node
with graph.inserting_before(first_user):
new_node = graph.node_copy(node, lambda n: n)
node.replace_all_uses_with(new_node)
graph.erase_node(node)
return new_node
def _get_last_node(nodes: Sequence[Node]) -> Node:
"""
Given a list of FX Nodes,
return the one that appears last in the graph's execution order.
"""
if not nodes:
raise ValueError("`nodes` must be a non-empty sequence of FX Node objects")
graph = nodes[0].graph
ordering = list(graph.nodes)
# Sanity check that all nodes are in same graph
valid = [n for n in nodes if n in ordering]
if not valid:
raise ValueError("None of the provided nodes belong to the same graph")
last = max(valid, key=lambda n: ordering.index(n))
return last
def _validate_rope_inputs(q_node: Node, k_node: Node) -> bool:
"""
Validates that:
- The last dimension (head_dim) of both q and k is a multiple of 64.
- The dtype of q and k is half precision (bfloat16 or float16).
- Layout should be [B,S,N,D] (dim 1 should be symbolic)
"""
for name, node in [("q", q_node), ("k", k_node)]:
fake_val = node.meta.get("val", None)
if fake_val is None:
ad_logger.warning(
f"""Sanity check failed: q_fake should have shape [b, s, n, d],
s should be symbolic and n should be int, instead got shape {q_fake.shape}"""
f"Meta['val'] for {name} not available; skipping RoPE transformation."
)
else:
ad_logger.warning(
f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}"
)
return False
# Retrieve or register the lookup table for inv_freq_node -> cos_sin_flash
cache = rope_flash_cache
if inv_freq_node in cache:
cos_sin_flash = cache[inv_freq_node]
else:
# Compute the fused cosine/sine cache.
with graph.inserting_after(inv_freq_node):
real_part = graph.call_function(torch.ops.aten.real, args=(inv_freq_node,))
imag_part = graph.call_function(torch.ops.aten.imag, args=(inv_freq_node,))
with graph.inserting_after(real_part):
cos_sin_flash_3d = graph.call_function(
torch.ops.aten.cat, args=((real_part, imag_part), -1)
# Check dtype
if fake_val.dtype not in (torch.float16, torch.bfloat16):
ad_logger.warning(
f"""{name} tensor is {fake_val.dtype},
expected half precision (float16 or bfloat16). Skipping RoPE transformation."""
)
with graph.inserting_after(cos_sin_flash_3d):
cos_sin_flash = graph.call_function(operator.getitem, args=(cos_sin_flash_3d, 0))
with graph.inserting_after(cos_sin_flash):
cos_sin_flash = graph.call_function(
torch.ops.aten.to.dtype, args=(cos_sin_flash, torch.float32)
return False
# Check head_dim
if len(fake_val.shape) < 1:
ad_logger.warning(f"{name} tensor has invalid shape {fake_val.shape}.")
return False
head_dim = fake_val.shape[-1]
if isinstance(head_dim, int) and head_dim % 64 != 0:
ad_logger.warning(
f"{name} head_dim = {head_dim} is not a multiple of 64. Skipping RoPE transformation."
)
cache[inv_freq_node] = cos_sin_flash
return False
with graph.inserting_before(q_match["out"]):
position_ids = _get_position_ids(
graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=rope_position_ids_cache
)
flash_node = graph.call_function(
torch.ops.rope.flashinfer,
args=(q_node, k_node, position_ids, cos_sin_flash, False),
)
# Check shape
if not isinstance(fake_val.shape[1], torch.SymInt):
ad_logger.warning(
f"{name} has shape {fake_val.shape} that is not supported. Only support [B, S, N, D] layout.\
Skipping RoPE transformation."
)
return False
with graph.inserting_after(flash_node):
raw_q = graph.call_function(operator.getitem, args=(flash_node, 0))
raw_k = graph.call_function(operator.getitem, args=(flash_node, 1))
raw_q.meta["val"] = q_match["out"].meta.get("val", None)
raw_k.meta["val"] = k_match["out"].meta.get("val", None)
q_match["out"].replace_all_uses_with(raw_q)
k_match["out"].replace_all_uses_with(raw_k)
return True
def _get_position_ids(
@ -491,22 +893,17 @@ def _get_position_ids(
sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, batch_dim))
sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, seq_dim))
bs_seq = graph.call_function(operator.mul, args=(sym_batch, sym_seq))
# Retrieve device information, ensuring it is a torch.device.
device = q_node.meta.get("device", "cpu")
if isinstance(device, str):
device = torch.device(device)
# Build positions: arange(sym_seq) -> view -> expand -> flatten.
positions_node = graph.call_function(
position_ids = graph.call_function(
torch.ops.aten.arange,
args=(sym_seq,),
args=(bs_seq,),
kwargs={"dtype": torch.float32, "device": device, "pin_memory": False},
)
positions_node = graph.call_function(torch.ops.aten.view, args=(positions_node, (1, -1)))
positions_node = graph.call_function(
torch.ops.aten.expand, args=(positions_node, (sym_batch, -1))
)
position_ids = graph.call_function(torch.ops.aten.flatten, args=(positions_node,))
rope_position_ids_cache["position_ids"] = position_ids
return position_ids

View File

@ -25,12 +25,14 @@ from .library import (
insert_cached_attention,
match_attention_layout,
match_causal_attn_mask,
match_complex_rope,
match_eager_attention,
match_explicit_rope,
match_grouped_attention,
match_moe_pattern,
match_repeat_kv,
match_rope_v1,
match_rope_v2,
match_rope_layout,
optimize_rope,
quantize,
resize_kv_cache,
)
@ -122,10 +124,10 @@ class InferenceOptimizer:
egm = match_attention_layout(egm, self.attention_op)
# Match rope
# TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
# see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
egm = match_rope_v1(egm)
egm = match_rope_v2(egm)
egm = match_explicit_rope(egm)
egm = match_complex_rope(egm)
# Match RoPE layout expected by our backend
egm = match_rope_layout(egm, self.attention_op.get_attention_layout())
############################################################################################
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
@ -134,6 +136,10 @@ class InferenceOptimizer:
# eliminate redundant transpose operations
egm = eliminate_redundant_transposes(egm)
# TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
# see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
egm = optimize_rope(egm)
# run TP sharding across ranks
egm = column_row_shard(egm, local_rank, world_size)

View File

@ -1,5 +1,6 @@
"""Common utils for torch fx graph transformation."""
import operator
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional, Tuple, Union
@ -320,3 +321,22 @@ def bfs(
visited.add(next_node)
queue.append(next_node)
raise RuntimeError(f"Could not find node with target condition {target}.")
def extract_output_tuple(node: Node, count: int = 2):
"""
Extract up to `count` outputs from a tuple-producing node.
Returns a list of length `count`, with None if an output isn't found.
"""
results = []
for idx in range(count):
user_node = next(
(
u
for u in node.users
if u.op == "call_function" and u.target == operator.getitem and u.args[1] == idx
),
None,
)
results.append(user_node)
return results

View File

@ -51,7 +51,7 @@ def run_test(
num_params_gm = count_parameters(gm)
assert num_params_model == num_params_gm
assert all_close(y_model, y_gm, atol=atol, rtol=rtol)
torch.testing.assert_close(y_model, y_gm, atol=atol, rtol=rtol)
# graph transformation + check
gm_transformed = transform(gm, *args)
@ -71,9 +71,7 @@ def run_test(
if strict_loading:
# check if output equals without loading state dict
assert all_close(y_model, y_transformed, atol=atol, rtol=rtol), (
f"{y_model=}, {y_transformed=}"
)
torch.testing.assert_close(y_model, y_transformed, atol=atol, rtol=rtol)
if test_load_hook:
# check if loading hook works from original state dict
@ -83,9 +81,7 @@ def run_test(
gm_transformed.load_state_dict(model.state_dict(), strict=True if strict_loading else False)
y_loaded_from_original = gm_transformed(x)
assert all_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol), (
f"{y_model=}, {y_loaded_from_original=}"
)
torch.testing.assert_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol)
# check if loading hook works from state_dict of a transformed model
state_dict_sharded = copy.deepcopy(gm_transformed.state_dict())
@ -95,9 +91,7 @@ def run_test(
gm_transformed.load_state_dict(state_dict_sharded, strict=True if strict_loading else False)
y_loaded_from_transformed = gm_transformed(x)
assert all_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol), (
f"{y_model=}, {y_loaded_from_transformed=}"
)
torch.testing.assert_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol)
# check if we can still export the model as expected
torch_export(gm_transformed, args=(x,))

View File

@@ -222,3 +222,55 @@ def _hf_model_dir_or_hub_id(
return hf_model_dir
else:
return hf_hub_id
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb_explicit(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def apply_rotary_pos_emb_complex(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2) and complex dtype.
unsqueeze_dim: int = 2,
):
# Reshape the inputs to pair the last dimension.
xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# Multiply with frequencies. Note that freqs_cis is expected to broadcast with an extra head dim.
freqs = freqs_cis.unsqueeze(unsqueeze_dim)
xq_out = torch.view_as_real(xq_complex * freqs).flatten(3)
xk_out = torch.view_as_real(xk_complex * freqs).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
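# A small sanity sketch, assuming the two helpers above: at position 0 the rotation angle is
# zero (cos=1, sin=0, freqs_cis=1+0j), so both variants leave q/k unchanged at that position.
import torch

B, S, N, D = 2, 3, 4, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))
angles = torch.arange(S).float().unsqueeze(1) * inv_freq.unsqueeze(0)  # [S, D/2]

# Explicit variant: [B, N, S, D] inputs with duplicated cos/sin of shape [B, S, D].
q = torch.randn(B, N, S, D)
k = torch.randn(B, N, S, D)
cos = torch.cat([angles.cos(), angles.cos()], dim=-1).unsqueeze(0).expand(B, -1, -1)
sin = torch.cat([angles.sin(), angles.sin()], dim=-1).unsqueeze(0).expand(B, -1, -1)
q_rot, _ = apply_rotary_pos_emb_explicit(q, k, cos, sin, unsqueeze_dim=1)
torch.testing.assert_close(q_rot[:, :, 0], q[:, :, 0])  # position 0 is a no-op

# Complex variant: [B, S, N, D] inputs with freqs_cis of shape [B, S, D/2].
xq = torch.randn(B, S, N, D)
xk = torch.randn(B, S, N, D)
freqs_cis = torch.polar(torch.ones(S, D // 2), angles).unsqueeze(0).expand(B, -1, -1)
xq_rot, _ = apply_rotary_pos_emb_complex(xq, xk, freqs_cis, unsqueeze_dim=2)
torch.testing.assert_close(xq_rot[:, 0], xq[:, 0])  # position 0 is a no-op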
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L339
def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""
Apply rotary positional embeddings with Q/K channel interleaving,
indexing the cos/sin tables with position_ids, and returning the rotated q, k.
cos: [seq_len, head_dim]
sin: [seq_len, head_dim]
"""
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
b, h, s, d = q.shape
q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
b, h, s, d = k.shape
k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
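# A minimal illustration, assuming head_dim=8: the view/transpose/reshape above maps the
# channel-interleaved layout [x0, y0, x1, y1, ...] to the half-split layout
# [x0, x1, ..., y0, y1, ...] that rotate_half expects.
import torch

d = 8
q_toy = torch.arange(d).view(1, 1, 1, d)  # one batch, head, position
q_perm = q_toy.view(1, 1, 1, d // 2, 2).transpose(4, 3).reshape(1, 1, 1, d)
assert q_perm.flatten().tolist() == [0, 2, 4, 6, 1, 3, 5, 7]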

View File

@@ -1,145 +0,0 @@
from typing import Tuple
import flashinfer
import pytest
import torch
import tensorrt_llm._torch.auto_deploy # noqa: F401
torch.manual_seed(0)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
@pytest.mark.parametrize("head_dim", [64, 256]) # head_dim must be a multiple of 64
@pytest.mark.parametrize(
"dtype,atol,rtol",
[
(torch.bfloat16, 1e-4, 1e-4),
(torch.float16, 5e-4, 5e-4),
],
ids=["bfloat16", "float16"], # q/k must be in half precision
)
def test_flashinfer_and_custom_rope_ops(dtype, atol, rtol, head_dim):
device = "cuda"
batch = 2
seq_len = 4
n_head = 3
# Prepare rotary embedding values.
inv_freq = 1.0 / (
10000
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
)
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # [seq_len, head_dim//2]
cos_vals = torch.cos(angles) # [seq_len, head_dim//2]
sin_vals = torch.sin(angles) # [seq_len, head_dim//2]
# For direct FlashInfer call: non-interleaved cache [seq_len, head_dim] (concatenated).
cos_sin_cache = torch.cat([cos_vals, sin_vals], dim=1)
# For HF and the custom op: duplicated layout [seq_len, head_dim].
cos_new = torch.cat([cos_vals, cos_vals], dim=-1)
sin_new = torch.cat([sin_vals, sin_vals], dim=-1)
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
# Direct FlashInfer kernel call.
query_flat = query.view(batch * seq_len, n_head * head_dim)
key_flat = key.view(batch * seq_len, n_head * head_dim)
positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)])
q_flash, k_flash = flashinfer.rope.apply_rope_with_cos_sin_cache(
positions, query_flat, key_flat, head_dim, cos_sin_cache, is_neox=True
)
q_flash = q_flash.view(batch, seq_len, n_head, head_dim)
k_flash = k_flash.view(batch, seq_len, n_head, head_dim)
# HF implementation using apply_rotary_pos_emb.
# HF expects [batch, n_head, seq_len, head_dim] for unsqueeze_dim=1
q_for_hf = query.transpose(1, 2).clone()
k_for_hf = key.transpose(1, 2).clone()
cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
q_hf, k_hf = apply_rotary_pos_emb(q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1)
# Convert outputs to [batch, seq_len, n_head, head_dim]
q_hf = q_hf.transpose(1, 2).to(dtype)
k_hf = k_hf.transpose(1, 2).to(dtype)
# Custom op call
custom_q, custom_k = torch.ops.rope.flashinfer(query, key, positions, cos_sin_cache, True)
torch.testing.assert_close(q_hf, q_flash, rtol=rtol, atol=atol)
torch.testing.assert_close(k_hf, k_flash, rtol=rtol, atol=atol)
torch.testing.assert_close(q_hf, custom_q, rtol=rtol, atol=atol)
torch.testing.assert_close(k_hf, custom_k, rtol=rtol, atol=atol)
# Version 2: complex multiplication approach
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2)
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
@pytest.mark.parametrize("head_dim", [64, 256]) # Must be a multiple of 64
@pytest.mark.parametrize(
"dtype,atol,rtol",
[
(torch.bfloat16, 1e-5, 1e-5),
(torch.float16, 5e-4, 5e-4),
],
ids=["bfloat16", "float16"], # q/k must be in half precision
)
def test_flashinfer_complex_rotary(dtype, atol, rtol, head_dim):
device = "cuda"
batch = 2
seq_len = 4
n_head = 3
inv_freq = 1.0 / (
10000
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
)
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # shape: (seq_len, head_dim//2)
freqs_cis = torch.polar(torch.ones((seq_len, head_dim // 2), device=device), angles)
freqs_cis = freqs_cis.unsqueeze(0).expand(batch, -1, -1) # shape: (B, seq, head_dim//2)
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
out_q_v2, out_k_v2 = apply_rotary_emb(query, key, freqs_cis)
cos_from_freqs = torch.real(freqs_cis) # (B, seq, head_dim//2)
sin_from_freqs = torch.imag(freqs_cis) # (B, seq, head_dim//2)
cos_sin_cache = torch.cat([cos_from_freqs, sin_from_freqs], dim=-1)[0] # (seq, head_dim))
# q/k of llama4 rope is interleaved
positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)])
custom_q, custom_k = torch.ops.rope.flashinfer(query, key, positions, cos_sin_cache, False)
torch.testing.assert_close(out_q_v2, custom_q, rtol=rtol, atol=atol)
torch.testing.assert_close(out_k_v2, custom_k, rtol=rtol, atol=atol)

View File

@@ -0,0 +1,294 @@
import flashinfer
import pytest
import torch
from _model_test_utils import (
apply_rotary_pos_emb_complex,
apply_rotary_pos_emb_ds,
apply_rotary_pos_emb_explicit,
)
import tensorrt_llm._torch.auto_deploy # noqa: F401
torch.manual_seed(0)
@pytest.mark.parametrize("head_dim", [64, 256]) # head_dim must be a multiple of 64
@pytest.mark.parametrize(
"dtype,atol,rtol",
[
(torch.bfloat16, 1e-4, 1e-4),
(torch.float16, 5e-4, 5e-4),
],
ids=["bfloat16", "float16"], # q/k must be in half precision
)
def test_flashinfer_custom_op_and_hf_impl(dtype, atol, rtol, head_dim):
"""
Verify FlashInfer's Neox RoPE kernel against HF's apply_rotary_pos_emb:
- Q/K: [B, S, N, D] non-interleaved half-precision.
- cos_sin_cache: [S, D] = [cos||sin] concatenated.
- HF path: Q/K [B, N, S, D], cos_new/sin_new: [S, D] duplicated, then broadcast to [B, S, D].
"""
device = "cuda"
batch = 2
seq_len = 4
n_head = 3
# Prepare rotary embedding values.
inv_freq = 1.0 / (
10000
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
)
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # [seq_len, head_dim//2]
cos_vals = torch.cos(angles) # [seq_len, head_dim//2]
sin_vals = torch.sin(angles) # [seq_len, head_dim//2]
# For direct FlashInfer call: non-interleaved cache [seq_len, head_dim] (concatenated).
cos_sin_cache = torch.cat([cos_vals, sin_vals], dim=1)
cos_sin_cache_expand = (
cos_sin_cache.unsqueeze(0).expand(batch, -1, -1).contiguous().view(batch * seq_len, -1)
) # [batch * seq_len, head_dim]
# For HF and the custom op: duplicated layout [seq_len, head_dim].
cos_new = torch.cat([cos_vals, cos_vals], dim=-1)
sin_new = torch.cat([sin_vals, sin_vals], dim=-1)
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
# Direct FlashInfer kernel call.
query_flat = query.view(batch * seq_len, n_head * head_dim)
key_flat = key.view(batch * seq_len, n_head * head_dim)
positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)])
q_flash, k_flash = flashinfer.rope.apply_rope_with_cos_sin_cache(
positions, query_flat, key_flat, head_dim, cos_sin_cache, is_neox=True
)
q_flash = q_flash.view(batch, seq_len, n_head, head_dim)
k_flash = k_flash.view(batch, seq_len, n_head, head_dim)
# HF implementation using apply_rotary_pos_emb.
# HF expects [batch, n_head, seq_len, head_dim] for unsqueeze_dim=1
q_for_hf = query.transpose(1, 2).clone()
k_for_hf = key.transpose(1, 2).clone()
cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
q_hf, k_hf = apply_rotary_pos_emb_explicit(
q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1
)
# Convert outputs to [batch, seq_len, n_head, head_dim]
q_hf = q_hf.transpose(1, 2).to(dtype)
k_hf = k_hf.transpose(1, 2).to(dtype)
# Custom op call
positions_flat = torch.arange(batch * seq_len, device=device)
custom_q, custom_k = torch.ops.rope.flashinfer(
query, key, positions_flat, cos_sin_cache_expand, True
)
torch.testing.assert_close(q_hf, q_flash, rtol=rtol, atol=atol)
torch.testing.assert_close(k_hf, k_flash, rtol=rtol, atol=atol)
torch.testing.assert_close(q_hf, custom_q, rtol=rtol, atol=atol)
torch.testing.assert_close(k_hf, custom_k, rtol=rtol, atol=atol)
@pytest.mark.parametrize("head_dim", [64, 256]) # Must be a multiple of 64
@pytest.mark.parametrize(
"dtype,atol,rtol",
[
(torch.bfloat16, 1e-5, 1e-5),
(torch.float16, 5e-4, 5e-4),
],
ids=["bfloat16", "float16"], # q/k must be in half precision
)
def test_flashinfer_custom_op_and_complex_impl(dtype, atol, rtol, head_dim):
"""
Check FlashInfer's RoPE matches the complex-multiplication approach:
- Q/K: [B, S, N, D] non-interleaved half-precision.
- freqs_cis: [B, S, D/2] complex polar values.
- flashinfer uses cos_sin_cache: [S, D] interleaved from real/imag of freqs_cis.
"""
device = "cuda"
batch = 2
seq_len = 4
n_head = 3
inv_freq = 1.0 / (
10000
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
)
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # shape: (seq_len, head_dim//2)
freqs_cis = torch.polar(torch.ones((seq_len, head_dim // 2), device=device), angles)
freqs_cis = freqs_cis.unsqueeze(0).expand(batch, -1, -1) # shape: (B, seq, head_dim//2)
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
out_q_v2, out_k_v2 = apply_rotary_pos_emb_complex(query, key, freqs_cis)
cos_from_freqs = torch.real(freqs_cis) # (B, seq, head_dim//2)
sin_from_freqs = torch.imag(freqs_cis) # (B, seq, head_dim//2)
cos_sin_cache = torch.cat([cos_from_freqs, sin_from_freqs], dim=-1)[0]  # (seq, head_dim)
cos_sin_cache_expand = (
cos_sin_cache.unsqueeze(0).expand(batch, -1, -1).contiguous().view(batch * seq_len, -1)
) # [batch * seq_len, head_dim]
# q/k in llama4 rope are interleaved
positions_flat = torch.arange(batch * seq_len, device=device)
custom_q, custom_k = torch.ops.rope.flashinfer(
query, key, positions_flat, cos_sin_cache_expand, False
)
torch.testing.assert_close(out_q_v2, custom_q, rtol=rtol, atol=atol)
torch.testing.assert_close(out_k_v2, custom_k, rtol=rtol, atol=atol)
# Copy of TritonWithFlattenedInputs._precompute_freqs_cis
def precompute_freqs_cis_interleaved(
seq_len: int, head_dim: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
"""
Precompute interleaved cosine and sine frequency cache for rotary position embeddings (RoPE).
Returns a tensor of shape [seq_len, head_dim//2, 2], where the last dimension
alternates [cos, sin] values for each rotary frequency.
cache[s, i, 0] == cos(position=s · inv_freq[i])
cache[s, i, 1] == sin(position=s · inv_freq[i]).
"""
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
t = torch.arange(seq_len, device=device)
angles = t.unsqueeze(1) * inv_freq.unsqueeze(0)
freqs_cis = torch.polar(torch.ones_like(angles), angles)
cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
return cache.to(dtype)
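# A quick shape/value sketch, assuming the helper above and CPU tensors for illustration:
# the cache is [S, D/2, 2] holding [cos, sin] pairs, and position 0 has angle zero (cos=1, sin=0).
import torch

_cache = precompute_freqs_cis_interleaved(seq_len=4, head_dim=8, dtype=torch.float32, device="cpu")
assert _cache.shape == (4, 4, 2)
torch.testing.assert_close(_cache[0, :, 0], torch.ones(4))  # cos at position 0
torch.testing.assert_close(_cache[0, :, 1], torch.zeros(4))  # sin at position 0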
@pytest.mark.parametrize("layout", ["bsnd", "bnsd"])
@pytest.mark.parametrize("head_dim", [64, 256])
@pytest.mark.parametrize(
"dtype, atol, rtol",
[
(torch.bfloat16, 1e-4, 1e-4),
(torch.float16, 5e-4, 5e-4),
],
ids=["bfloat16", "float16"],
)
def test_triton_custom_op_and_hf_impl(layout, head_dim, dtype, atol, rtol):
"""
Validate custom Triton apply_rope_with_input_pos against HF's apply_rotary_pos_emb:
- Q/K: layout 'bsnd'[B,S,N,D] or 'bnsd'[B,N,S,D], non-interleaved half-precision.
- cosin_cache: [S, D/2, 2] interleaved [cos,sin].
- HF path: cos_full/sin_full: [S, D] then expanded to [B, S, D].
"""
device = "cuda"
batch, seq_len, n_head = 2, 4, 3
# build cache and per-batch zero positions
cosin_cache = precompute_freqs_cis_interleaved(seq_len, head_dim, dtype, device) # [S, D/2, 2]
positions = torch.zeros(batch, dtype=torch.int32, device=device)
if layout == "bsnd":
q = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
k = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
unsq = 2
else: # "bnsd"
q = torch.randn(batch, n_head, seq_len, head_dim, dtype=dtype, device=device)
k = torch.randn(batch, n_head, seq_len, head_dim, dtype=dtype, device=device)
unsq = 1
# build HF float32 cos/sin full tensors
cos_f32 = cosin_cache[..., 0].to(torch.float32) # [S, H/2]
sin_f32 = cosin_cache[..., 1].to(torch.float32) # [S, H/2]
cos_full = torch.cat([cos_f32, cos_f32], dim=1) # [S, H]
sin_full = torch.cat([sin_f32, sin_f32], dim=1) # [S, H]
cos_exp = cos_full.unsqueeze(0).expand(batch, -1, -1) # [B, S, H]
sin_exp = sin_full.unsqueeze(0).expand(batch, -1, -1) # [B, S, H]
# HF reference in float32, then cast back
q_f32, k_f32 = apply_rotary_pos_emb_explicit(
q.to(torch.float32), k.to(torch.float32), cos_exp, sin_exp, unsqueeze_dim=unsq
)
q_hf = q_f32.to(dtype)
k_hf = k_f32.to(dtype)
q_out = torch.ops.rope.apply_rope_with_input_pos(q, cosin_cache, positions, layout)
k_out = torch.ops.rope.apply_rope_with_input_pos(k, cosin_cache, positions, layout)
torch.testing.assert_close(q_hf, q_out, atol=atol, rtol=rtol)
torch.testing.assert_close(k_hf, k_out, atol=atol, rtol=rtol)
def inverse_interleave_permute_for_rotary(x: torch.Tensor) -> torch.Tensor:
b, h, s, d = x.shape
x = x.view(b, h, s, 2, d // 2)
x = x.transpose(4, 3)
return x.reshape(b, h, s, d)
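# A round-trip sketch, assuming head_dim=8: inverse_interleave_permute_for_rotary is the
# inverse of the view/transpose/reshape permutation applied inside apply_rotary_pos_emb_ds,
# so composing the two recovers the original channel order.
import torch

x_toy = torch.arange(8).view(1, 1, 1, 8)
ds_permuted = x_toy.view(1, 1, 1, 4, 2).transpose(4, 3).reshape(1, 1, 1, 8)  # [0,2,4,6,1,3,5,7]
assert torch.equal(inverse_interleave_permute_for_rotary(ds_permuted), x_toy)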
@pytest.mark.parametrize("head_dim", [64, 256])
@pytest.mark.parametrize(
"dtype, atol, rtol",
[
(torch.bfloat16, 1e-4, 1e-4),
(torch.float16, 5e-4, 5e-4),
],
ids=["bfloat16", "float16"],
)
def test_ds_impl_and_hf_impl(dtype, head_dim, atol, rtol):
"""
Ensure Deepseek's interleaved-Q/K RoPE matches HF apply_rotary_pos_emb:
- DS Q/K: [B, N, S, D] channel-interleaved in last dim.
- cos_new/sin_new: [S, D] duplicated real values.
- HF path: Q/K [B,N,S,D], cos_expand/sin_expand: [B,S,D], unsqueezed at dim=1.
"""
device = "cuda"
batch = 2
seq_len = 4
n_head = 3
position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch, seq_len)
inv_freq = 1.0 / (
10000
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
)
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # [seq_len, head_dim//2]
cos_vals = torch.cos(angles) # [seq_len, head_dim//2]
sin_vals = torch.sin(angles) # [seq_len, head_dim//2]
# duplicate to shape [seq_len, head_dim]
cos_new = torch.cat([cos_vals, cos_vals], dim=-1)
sin_new = torch.cat([sin_vals, sin_vals], dim=-1)
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
# HF torch expects inputs of shape [B, N, S, D]
q_for_hf = query.transpose(1, 2).clone()
k_for_hf = key.transpose(1, 2).clone()
cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
q_rotated_hf, k_rotated_hf = apply_rotary_pos_emb_explicit(
q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1
)
q_rotated_hf = q_rotated_hf.transpose(1, 2).to(torch.float32)
k_rotated_hf = k_rotated_hf.transpose(1, 2).to(torch.float32)
q_for_hf2 = inverse_interleave_permute_for_rotary(q_for_hf.clone())
k_for_hf2 = inverse_interleave_permute_for_rotary(k_for_hf.clone())
# adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L134
t = torch.arange(seq_len, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_ds = emb.cos()
sin_ds = emb.sin()
torch.testing.assert_close(cos_ds, cos_new, rtol=rtol, atol=atol)
torch.testing.assert_close(sin_ds, sin_new, rtol=rtol, atol=atol)
q_rotated_hf2, k_rotated_hf2 = apply_rotary_pos_emb_ds(
q_for_hf2, k_for_hf2, cos_new, sin_new, position_ids, unsqueeze_dim=1
)
torch.testing.assert_close(q_rotated_hf2.transpose(1, 2), q_rotated_hf, rtol=rtol, atol=atol)
torch.testing.assert_close(k_rotated_hf2.transpose(1, 2), k_rotated_hf, rtol=rtol, atol=atol)

View File

@@ -1,211 +0,0 @@
import pytest
import torch
from _graph_test_helpers import run_test
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.transformations.library.rope import (
match_rope_v1,
match_rope_v2,
)
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
torch.manual_seed(0)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def _precompute_freqs_cis(seq_len: int, head_dim: int, rope_theta: float):
dtype = torch.float32
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
positions = torch.arange(seq_len, dtype=torch.float32)
freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype)
sin = emb.sin().to(dtype)
return cos, sin
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2) and complex dtype.
):
# Reshape the inputs to pair the last dimension.
xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# Multiply with frequencies. Note that freqs_cis is expected to broadcast with an extra head dim.
xq_out = torch.view_as_real(xq_complex * freqs_cis[:, :, None, :]).flatten(3)
xk_out = torch.view_as_real(xk_complex * freqs_cis[:, :, None, :]).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def _precompute_freqs_cis_v2(seq_len: int, head_dim: int, rope_theta: float):
"""
Compute the frequency tensor for the complex multiplication RoPE variant.
Returns a complex tensor of shape (seq_len, head_dim//2).
"""
inv_freq = 1.0 / (
rope_theta ** (torch.arange(0, head_dim // 2, dtype=torch.float32) / (head_dim // 2))
)
positions = torch.arange(seq_len, dtype=torch.float32)
angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # (seq_len, head_dim//2)
# Create a complex tensor from magnitude=1 and the computed angles.
freqs_cis = torch.polar(torch.ones_like(angles), angles)
return freqs_cis
class RotaryModel(torch.nn.Module):
def __init__(
self,
hidden_size: int,
max_seq_len: int,
num_heads: int,
num_kv_heads: int,
layout: str = "BNSD",
):
super().__init__()
self.hidden_size = hidden_size
self.max_seq_len = max_seq_len
self.layout = layout
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = hidden_size // num_heads
self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim)
self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
q = self.linear_q(x)
k = self.linear_k(x)
batch, seq, _ = q.shape
q = q.view(batch, seq, self.num_heads, self.head_dim)
k = k.view(batch, seq, self.num_kv_heads, self.head_dim)
if self.layout == "BNSD":
q = q.permute(0, 2, 1, 3).contiguous() # [B, N, S, D]
k = k.permute(0, 2, 1, 3).contiguous()
unsqueeze_dim = 1
else: # BSND
unsqueeze_dim = 2
cos, sin = _precompute_freqs_cis(seq, self.head_dim, rope_theta=10000)
cos = cos.to(q.device).unsqueeze(0).expand(batch, -1, -1)
sin = sin.to(q.device).unsqueeze(0).expand(batch, -1, -1)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim)
if self.layout == "BNSD":
# [B, N, S, D] -> [B, S, N*D]
q_embed = q_embed.permute(0, 2, 1, 3).reshape(batch, seq, -1)
k_embed = k_embed.permute(0, 2, 1, 3).reshape(batch, seq, -1)
else: # BSND
q_embed = q_embed.reshape(batch, seq, -1)
k_embed = k_embed.reshape(batch, seq, -1)
output = torch.cat([q_embed, k_embed], dim=-1)
return output.to(torch.float16)
def get_dynamic_shapes(self):
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}
class RotaryModelV2(torch.nn.Module):
def __init__(self, hidden_size: int, max_seq_len: int, num_heads: int, num_kv_heads: int):
super().__init__()
self.hidden_size = hidden_size
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = hidden_size // num_heads
self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim)
self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch, seq, _ = x.shape
q = self.linear_q(x).view(batch, seq, self.num_heads, self.head_dim)
k = self.linear_k(x).view(batch, seq, self.num_kv_heads, self.head_dim)
freqs_cis = _precompute_freqs_cis_v2(seq, self.head_dim, rope_theta=10000)
freqs_cis = freqs_cis.to(x.device).unsqueeze(0).expand(batch, -1, -1)
q_embed, k_embed = apply_rotary_emb(q, k, freqs_cis)
q_embed = q_embed.reshape(batch, seq, -1)
k_embed = k_embed.reshape(batch, seq, -1)
output = torch.cat([q_embed, k_embed], dim=-1)
return output.to(torch.float16)
def get_dynamic_shapes(self):
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}
@pytest.mark.parametrize("layout", ["BNSD", "BSND"])
@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4)])
@torch.inference_mode()
def test_match_rope(layout, num_heads, num_kv_heads):
batch_size, seq_len = 8, 16
hidden_size = 512
max_position_embeddings = seq_len
model = RotaryModel(
hidden_size, max_position_embeddings, num_heads, num_kv_heads, layout=layout
).to("cuda", dtype=torch.float16)
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
dynamic_shapes = model.get_dynamic_shapes()
_ = run_test(
model,
x,
match_rope_v1,
lambda gm: any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes),
lambda num_p_og: num_p_og,
atol=1e-3,
rtol=1e-3,
test_load_hook=True,
strict_loading=True,
dynamic_shapes=dynamic_shapes,
)
@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4)])
@torch.inference_mode()
def test_match_rope_v2(num_heads, num_kv_heads):
batch_size, seq_len = 8, 16
hidden_size = 512
max_position_embeddings = seq_len
model = RotaryModelV2(hidden_size, max_position_embeddings, num_heads, num_kv_heads).to(
"cuda", dtype=torch.float16
)
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
dynamic_shapes = model.get_dynamic_shapes()
_ = run_test(
model,
x,
match_rope_v2,
lambda gm: any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes),
lambda num_p_og: num_p_og,
atol=1e-3,
rtol=1e-3,
test_load_hook=True,
strict_loading=True,
dynamic_shapes=dynamic_shapes,
)

View File

@@ -0,0 +1,429 @@
import pytest
import torch
from _graph_test_helpers import run_test
from _model_test_utils import (
apply_rotary_pos_emb_complex,
apply_rotary_pos_emb_ds,
apply_rotary_pos_emb_explicit,
)
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.transformations.library.rope import (
match_complex_rope,
match_explicit_rope,
match_rope_layout,
optimize_rope,
)
from tensorrt_llm._torch.auto_deploy.utils.node_utils import extract_output_tuple, is_op
torch.manual_seed(0)
def _precompute_freqs_cis_explicit(seq_len: int, head_dim: int, rope_theta: float):
dtype = torch.float32
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
positions = torch.arange(seq_len, dtype=torch.float32)
freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype)
sin = emb.sin().to(dtype)
return cos, sin
def _precompute_freqs_cis_complex(seq_len: int, head_dim: int, rope_theta: float):
"""
Compute the frequency tensor for the complex multiplication RoPE variant.
Returns a complex tensor of shape (seq_len, head_dim//2).
"""
inv_freq = 1.0 / (
rope_theta ** (torch.arange(0, head_dim // 2, dtype=torch.float32) / (head_dim // 2))
)
positions = torch.arange(seq_len, dtype=torch.float32)
angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # (seq_len, head_dim//2)
# Create a complex tensor from magnitude=1 and the computed angles.
freqs_cis = torch.polar(torch.ones_like(angles), angles)
return freqs_cis
class RoPEModel(torch.nn.Module):
def __init__(
self,
hidden_size: int,
max_seq_len: int,
num_heads: int,
num_kv_heads: int,
variant: str = "explicit", # "explicit" or "complex"
mode: str = "match", # "match" or "optimize"
layout: str = "BNSD", # "BNSD" or "BSND"
):
super().__init__()
self.hidden_size = hidden_size
self.max_seq_len = max_seq_len
self.variant = variant
self.mode = mode
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = hidden_size // num_heads
self.layout = layout
self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim)
self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, s, _ = x.shape
q = self.linear_q(x)
k = self.linear_k(x)
if self.variant == "explicit":
# reshape and permute if BNSD layout
q = q.view(b, s, self.num_heads, self.head_dim)
k = k.view(b, s, self.num_kv_heads, self.head_dim)
if self.layout == "BNSD":
q = q.permute(0, 2, 1, 3).contiguous()
k = k.permute(0, 2, 1, 3).contiguous()
unsq_dim = 1
else:
unsq_dim = 2
cos, sin = _precompute_freqs_cis_explicit(s, self.head_dim, rope_theta=10000)
cos = cos.to(x.device).unsqueeze(0).expand(b, -1, -1)
sin = sin.to(x.device).unsqueeze(0).expand(b, -1, -1)
if self.mode == "match":
q_out, k_out = apply_rotary_pos_emb_explicit(q, k, cos, sin, unsq_dim)
else: # optimize
q_out, k_out = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin(
q, k, cos, sin, unsq_dim
)
# revert layout and flatten
if self.layout == "BNSD":
q_out = q_out.permute(0, 2, 1, 3).reshape(b, s, -1)
k_out = k_out.permute(0, 2, 1, 3).reshape(b, s, -1)
else:
q_out = q_out.reshape(b, s, -1)
k_out = k_out.reshape(b, s, -1)
else: # complex variant
q = q.view(b, s, self.num_heads, self.head_dim)
k = k.view(b, s, self.num_kv_heads, self.head_dim)
if self.layout == "BNSD":
q = q.permute(0, 2, 1, 3).contiguous()
k = k.permute(0, 2, 1, 3).contiguous()
unsq_dim = 1
else:
unsq_dim = 2
freqs = _precompute_freqs_cis_complex(s, self.head_dim, rope_theta=10000)
freqs = freqs.to(x.device).unsqueeze(0).expand(b, -1, -1)
if self.mode == "match":
q_out, k_out = apply_rotary_pos_emb_complex(q, k, freqs, unsq_dim)
else:
q_out, k_out = torch.ops.rope.torch_apply_rope_with_complex_freqs(
q, k, freqs, unsq_dim
)
# revert layout and flatten
if self.layout == "BNSD":
q_out = q_out.permute(0, 2, 1, 3).reshape(b, s, -1)
k_out = k_out.permute(0, 2, 1, 3).reshape(b, s, -1)
else:
q_out = q_out.reshape(b, s, -1)
k_out = k_out.reshape(b, s, -1)
out = torch.cat([q_out, k_out], dim=-1)
return out.to(torch.float16) if self.mode == "match" else out
def get_dynamic_shapes(self):
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}
@pytest.mark.parametrize(
"transformation,variant,layout,batch_size,seq_len,num_heads,num_kv_heads,atol,rtol, target_layout",
[
("match", "explicit", "BNSD", 8, 16, 8, 8, 1e-2, 1e-2, None),
("match", "explicit", "BSND", 8, 16, 8, 4, 1e-2, 1e-2, None),
("match", "complex", "BNSD", 8, 16, 8, 8, 1e-3, 1e-3, None),
("match", "complex", "BSND", 8, 16, 8, 4, 1e-3, 1e-3, None),
("match_layout", "explicit", "BNSD", 4, 12, 8, 8, 1e-3, 1e-3, "BSND"),
("match_layout", "explicit", "BNSD", 4, 12, 8, 8, 1e-3, 1e-3, "BNSD"),
("match_layout", "complex", "BNSD", 4, 12, 8, 8, 1e-3, 1e-3, "BSND"),
("match_layout", "complex", "BSND", 4, 12, 8, 8, 1e-3, 1e-3, "BSND"),
pytest.param(
"optimize",
"explicit",
"BNSD",
4,
12,
8,
8,
1e-3,
1e-3,
None,
marks=pytest.mark.xfail(
reason="flashinfer op does not support BNSD layout", strict=True
),
),
("optimize", "explicit", "BSND", 4, 12, 8, 4, 1e-3, 1e-3, None),
pytest.param(
"optimize",
"complex",
"BNSD",
4,
12,
8,
8,
1e-3,
1e-3,
None,
marks=pytest.mark.xfail(
reason="flashinfer op does not support BNSD layout", strict=True
),
),
("optimize", "complex", "BSND", 4, 12, 8, 4, 1e-3, 1e-3, None),
],
)
@torch.inference_mode()
def test_rope_variants(
transformation,
variant,
layout,
batch_size,
seq_len,
num_heads,
num_kv_heads,
atol,
rtol,
target_layout,
):
hidden_size = 512
model = RoPEModel(
hidden_size,
seq_len,
num_heads,
num_kv_heads,
variant=variant,
mode=transformation,
layout=layout or "BNSD",
).to("cuda", torch.float16)
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
dyn = model.get_dynamic_shapes()
if transformation == "match":
fn = match_explicit_rope if variant == "explicit" else match_complex_rope
check_op = (
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin
if variant == "explicit"
else torch.ops.rope.torch_apply_rope_with_complex_freqs
)
def checker(gm):
return any(is_op(n, check_op) for n in gm.graph.nodes)
elif transformation == "match_layout":
fn = match_rope_layout
def checker(gm):
for n in gm.graph.nodes:
if is_op(
n,
{
torch.ops.rope.torch_apply_rope_with_explicit_cos_sin,
torch.ops.rope.torch_apply_rope_with_complex_freqs,
},
):
q_arg, k_arg, *rest = n.args
if not (
is_op(q_arg, torch.ops.aten.contiguous)
and is_op(k_arg, torch.ops.aten.contiguous)
):
matched = False
break
old_q, old_k = extract_output_tuple(n, 2)
if old_q is None or old_k is None:
matched = False
break
q_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_q.users)
k_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_k.users)
matched = q_transposed and k_transposed
return matched if layout != target_layout else not matched
else:
fn = optimize_rope
def checker(gm):
return any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes)
if target_layout:
_ = run_test(
model,
x,
fn,
checker,
lambda n: n,
atol, # atol
rtol, # rtol
True, # test_load_hook
True, # strict_loading
dyn, # dynamic_shapes
target_layout,
)
else:
_ = run_test(
model,
x,
fn,
checker,
lambda n: n,
atol, # atol
rtol, # rtol
True, # test_load_hook
True, # strict_loading
dyn, # dynamic_shapes
)
class DSRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(max_position_embeddings, device=device, dtype=inv_freq.dtype)
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos(), persistent=False)
self.register_buffer("sin_cached", emb.sin(), persistent=False)
def forward(self, x, seq_len=None):
# returns [seq_len, head_dim] cos & sin
return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype)
class DSModel(torch.nn.Module):
def __init__(self, hidden_size, max_seq, n_head, n_kv, layout="BNSD", mode: str = "match"):
super().__init__()
self.hdim = hidden_size // n_head
self.layout = layout
self.q_lin = torch.nn.Linear(hidden_size, n_head * self.hdim)
self.k_lin = torch.nn.Linear(hidden_size, n_kv * self.hdim)
self.rotary = DSRotaryEmbedding(self.hdim, max_seq, base=10000, device="cuda")
self.mode = mode # "match" or "optimize"
def forward(self, x):
b, s, _ = x.shape
q = self.q_lin(x).view(b, s, -1, self.hdim)
k = self.k_lin(x).view(b, s, -1, self.hdim)
if self.layout == "BNSD":
# to [B, N, S, D]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
unsq_dim = 1
else:
unsq_dim = 2
cos, sin = self.rotary(x, seq_len=s)
# build position_ids [B, S]
pos_ids = torch.arange(s, device=x.device).unsqueeze(0).expand(b, s)
if self.mode == "match":
q_out, k_out = apply_rotary_pos_emb_ds(q, k, cos, sin, pos_ids, unsqueeze_dim=unsq_dim)
else:
cos = cos[pos_ids]
sin = sin[pos_ids]
q_out, k_out = torch.ops.rope.torch_apply_rope_with_qk_interleaving(
q, k, cos, sin, unsq_dim
)
if self.layout == "BNSD":
# back to [B, S, N*D]
q_out = q_out.permute(0, 2, 1, 3).reshape(b, s, -1)
k_out = k_out.permute(0, 2, 1, 3).reshape(b, s, -1)
else:
q_out = q_out.reshape(b, s, -1)
k_out = k_out.reshape(b, s, -1)
return torch.cat([q_out, k_out], dim=-1)
def get_dynamic_shapes(self):
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}
@pytest.mark.parametrize(
"layout,num_heads,num_kv_heads,mode, target_layout",
[
("BNSD", 8, 8, "match", None),
("BSND", 8, 4, "match", None),
("BNSD", 8, 8, "match_layout", "BNSD"),
("BSND", 8, 4, "match_layout", "BNSD"),
("BSND", 8, 4, "match_layout", "BSND"),
("BNSD", 8, 4, "match_layout", "BSND"),
],
)
@torch.inference_mode()
def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target_layout):
batch, seq, hid = 4, 12, 512
model = DSModel(hid, 16, num_heads, num_kv_heads, layout=layout, mode=mode)
model = model.to("cuda", torch.float16)
x = torch.randn(batch, seq, hid, device="cuda", dtype=torch.float16)
dynamic_shapes = model.get_dynamic_shapes()
if mode == "match":
transform = match_explicit_rope
def checker(gm):
return any(
is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving)
for n in gm.graph.nodes
)
else: # mode == "match_layout"
transform = match_rope_layout
def checker(gm):
for n in gm.graph.nodes:
if is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving):
q_arg, k_arg, *rest = n.args
if not (
is_op(q_arg, torch.ops.aten.contiguous)
and is_op(k_arg, torch.ops.aten.contiguous)
):
matched = False
break
old_q, old_k = extract_output_tuple(n, 2)
if old_q is None or old_k is None:
matched = False
break
q_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_q.users)
k_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_k.users)
matched = q_transposed and k_transposed
return matched if layout != target_layout else not matched
if target_layout:
_ = run_test(
model,
x,
transform,
checker,
lambda num_p: num_p,
1e-3, # atol
1e-3, # rtol
True, # test_load_hook
True, # strict_loading
dynamic_shapes, # dynamic_shapes
target_layout,
)
else:
_ = run_test(
model,
x,
transform,
checker,
lambda num_p: num_p,
1e-3, # atol
1e-3, # rtol
True, # test_load_hook
True, # strict_loading
dynamic_shapes, # dynamic_shapes
)