feat:[AutoDeploy] Enhance RoPE support (#3115)

* add test to map flashinfer rope op with triton custom rope ops and pytorch rope in fused_mha

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* add rope matcher and unit tests

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* capture cos and sin from graph

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* revert fuse_mha op change

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor update to address comment and remove redundant unit test

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* move view and transpose into graph nodes and update unit test to test custom op directly

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* move view into custom op, update bfs with bound, update custom op return type to be half precision

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* custom op update to support 3D input

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* handle bnsd and bsnd format, update tests, handle 3D cos/sin input to the custom op

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* add llama4 rope test, update custom op with is_neox flag

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* add llama4 style rope to matcher and update unit test

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* separate into two transformations

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* fix when num_head != num_kv_head; add support for cached position_ids and cos_sin_cache in graph; update unit tests

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor update, cache locally and propagate meta info of qk nodes

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor: fix cos_sin_cache not float

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

* minor: move cache into matcher

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>

---------

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
Fridah-nv 2025-04-11 08:51:24 -07:00 committed by GitHub
parent 11b0091863
commit ec723fa993
10 changed files with 959 additions and 20 deletions

View File

@@ -2,6 +2,7 @@
from .dist import *
from .flashinfer_attention import *
from .flashinfer_rope import *
from .fused_moe import *
from .linear import *
from .mla import *

View File

@@ -0,0 +1,65 @@
from typing import Tuple
import flashinfer
import torch
@torch.library.custom_op("rope::flashinfer", mutates_args=())
def apply_rope_with_input_pos_flashinfer(
q: torch.Tensor,
k: torch.Tensor,
position_ids: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Applies rotary positional embeddings (RoPE) to query and key tensors using the FlashInfer kernel.
This updated version expects precomputed positional IDs and a fused cosine-sine cache.
Inputs:
- q, k (torch.Tensor):
Tensors of shape [batch, seq_len, n_head, head_dim] (or a 3D variant)
in half precision. Note: head_dim must be a multiple of 64.
- position_ids (torch.Tensor):
Precomputed tensor of positional indices; it is shared across calls in the graph.
- cos_sin_cache (torch.Tensor):
Precomputed fused tensor created by concatenating the first half of the cosine and sine
components derived from the inv_freq.
- is_neox (bool):
Flag to indicate whether to invoke the FlashInfer kernel in Neox mode.
Returns:
A tuple of:
- Rotated query tensor of the same shape and half precision as input.
- Rotated key tensor of the same shape and half precision as input.
"""
q_shape = q.shape
k_shape = k.shape
batch_size, seq_len = q_shape[:2]
head_dim = cos_sin_cache.shape[-1]
q_flat = q.view(batch_size * seq_len, -1)
k_flat = k.view(batch_size * seq_len, -1)
position_ids = position_ids.to(q.device)
print("cos_sin_cache.shape", cos_sin_cache.shape)
query_rotated_flash, key_rotated_flash = flashinfer.rope.apply_rope_with_cos_sin_cache(
position_ids, q_flat, k_flat, head_dim, cos_sin_cache, is_neox=is_neox
)
query_rotated_flash = query_rotated_flash.view(q_shape)
key_rotated_flash = key_rotated_flash.view(k_shape)
return query_rotated_flash, key_rotated_flash
@apply_rope_with_input_pos_flashinfer.register_fake
def apply_rope_with_input_pos_flashinfer_fake(
q: torch.Tensor,
k: torch.Tensor,
position_ids: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
return q, k
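
Below is a minimal usage sketch of the new op, mirroring the unit tests added later in this commit. It assumes a CUDA device with flashinfer installed; the theta base of 10000, the tensor sizes, and the variable names are illustrative, not requirements beyond those stated in the docstring (half-precision q/k, head_dim a multiple of 64, fused [cos | sin] cache, flattened position ids).

import torch

import tensorrt_llm._torch.auto_deploy  # noqa: F401  (registers torch.ops.rope.flashinfer)

head_dim, seq_len, batch, n_head = 64, 4, 2, 3
device = "cuda"

# Fused cache: first half holds cos, second half holds sin, derived from inv_freq.
inv_freq = 1.0 / (
    10000 ** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
)
angles = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1) * inv_freq.unsqueeze(0)
cos_sin_cache = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)  # [seq_len, head_dim]

# Flattened position ids, shared across the batch.
position_ids = torch.arange(seq_len, device=device).repeat(batch)

q = torch.randn(batch, seq_len, n_head, head_dim, dtype=torch.float16, device=device)
k = torch.randn(batch, seq_len, n_head, head_dim, dtype=torch.float16, device=device)
q_rot, k_rot = torch.ops.rope.flashinfer(q, k, position_ids, cos_sin_cache, True)  # is_neox=True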

View File

@@ -20,7 +20,6 @@ def apply_rotary_pos_emb(
):
"""
Apply rotary positional embeddings to query and key tensors.
Args:
q: Query tensor of shape [batch, n_heads, seq_len, head_dim]
k: Key tensor of shape [batch, n_kv_heads, seq_len, head_dim]
@@ -28,7 +27,6 @@
head_dim: Dimension of each head
rope_theta: Base value for RoPE (default 10000.0)
rope_scale: Scaling factor for positions (default 1.0)
Returns:
Tuple of transformed query and key tensors
"""
@@ -108,7 +106,6 @@ def fused_mha(
rope_scale: Optional[float] = None,
) -> torch.Tensor:
"""Fused MHA+Rope that takes raw input from q, k, v GEMMs.
Rope is performed according to the specified rope configuration. No support for caching.
"""
# b, s info

View File

@@ -7,6 +7,7 @@ from .fused_moe import *
from .fusion import *
from .kvcache import *
from .quantization import *
from .rope import *
from .sharding import *
try:

View File

@@ -1,7 +1,7 @@
"""Pattern matcher to detect MHA pattern and replace with simple fused_mha op."""
from collections import defaultdict
from typing import Callable, Dict, List, Optional
from typing import Dict, List, Optional
import torch
from torch._subclasses import FakeTensor
@@ -9,7 +9,7 @@ from torch.fx import Graph, GraphModule, Node
from ...models.factory import PositionalEmbeddingConfig
from ...utils.logger import ad_logger
from ...utils.node_utils import is_dist_op, is_linear_op, is_op
from ...utils.node_utils import bfs, is_dist_op, is_linear_op, is_op
from .._graph import canonicalize_graph
@@ -20,16 +20,6 @@ def _is_dist_lin_op(node: Node, exclude: Optional[List[Node]] = None) -> bool:
)
def _bfs(node: Node, target: Callable, attr_next: str = "users") -> Node:
queue = [node]
while queue:
cur_node = queue.pop(0)
if target(cur_node):
return cur_node
queue.extend(getattr(cur_node, attr_next))
raise RuntimeError(f"Could not find node with target condition {target}.")
def identify_and_fuse_mha(
egm: GraphModule, pos_embd_config: PositionalEmbeddingConfig
) -> GraphModule:
@@ -66,9 +56,9 @@ def identify_and_fuse_mha(
# from the sdpa node, identify q, k, v, and out GEMMs via BFS
for arg in mha_node.args[:3]:
mha_gemms[mha_node].append(
_bfs(arg, lambda n: _is_dist_lin_op(n, mha_gemms[mha_node]), "all_input_nodes")
bfs(arg, lambda n: _is_dist_lin_op(n, mha_gemms[mha_node]), "all_input_nodes")
)
mha_gemms[mha_node].append(_bfs(mha_node, _is_dist_lin_op, "users"))
mha_gemms[mha_node].append(bfs(mha_node, _is_dist_lin_op, "users"))
# get fake q tensor that is an MHA input node to retrieve head_dim
q_fake: FakeTensor = mha_node.args[0].meta["val"]

View File

@@ -0,0 +1,508 @@
import operator
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional
import torch
from torch.fx import GraphModule, Node
from ...utils.logger import ad_logger
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_op
from .._graph import canonicalize_graph
def match_rope_v1(gm: GraphModule) -> GraphModule:
"""
Identify and replace legacy RoPE subgraphs (explicit cos/sin multiplication pattern):
output = (raw * unsqueeze(cos)) + (rotate_half(raw) * unsqueeze(sin))
If exactly two such branches (query and key) are detected within each region, they're replaced
by a call to `torch.ops.rope.flashinfer`.
"""
graph = gm.graph
boundary_nodes: List[torch.fx.Node] = identify_regions_between_residuals(gm)
rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None)
rope_position_ids_cache: Dict[str, Node] = {}
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
matches = []
node = start_boundary
while node != end_boundary:
if is_op(node, torch.ops.aten.add):
match_info = _match_rotary_subpattern_V1(node)
if match_info:
matches.append(match_info)
node = node.next
if not matches:
continue
if len(matches) != 2:
raise RuntimeError(
f"Expected exactly 2 legacy RoPE branches between {start_boundary} and {end_boundary}, "
f"found {len(matches)}."
)
# Assume the first matched branch is query (q), second is key (k).
# This assumption is based on the default ordering in the exported graph,
# since node naming conventions don't reliably indicate q/k branches.
q_match, k_match = matches
_process_rope_v1(
graph,
q_match,
k_match,
start_boundary,
rope_flash_cache,
rope_position_ids_cache,
)
gm = canonicalize_graph(gm)
return gm
def match_rope_v2(gm: GraphModule) -> GraphModule:
"""
Identify and replace RoPE subgraphs using complex multiplication pattern:
output = type_as(flatten(view_as_real(mul(view_as_complex(reshape(to_dtype(x))), unsqueeze(freqs_cis, 2)))), x)
If exactly two such branches (query and key) are detected within each region, they're replaced
by a call to `torch.ops.rope.flashinfer`.
"""
graph = gm.graph
boundary_nodes: List[torch.fx.Node] = identify_regions_between_residuals(gm)
rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None)
rope_position_ids_cache: Dict[str, Node] = {}
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
matches = []
node = start_boundary
while node != end_boundary:
if is_op(node, torch.ops.aten.type_as):
match_info = _match_rotary_subpattern_V2(node)
if match_info:
matches.append(match_info)
node = node.next
if not matches:
continue
if len(matches) != 2:
raise RuntimeError(
f"Expected exactly 2 complex RoPE branches between {start_boundary} and {end_boundary}, "
f"found {len(matches)}."
)
# Assume the first matched branch is query (q), second is key (k).
# This assumption is based on the default ordering in the exported graph,
# since node naming conventions don't reliably indicate q/k branches.
q_match, k_match = matches
_process_rope_v2(
graph,
q_match,
k_match,
rope_flash_cache,
rope_position_ids_cache,
)
gm = canonicalize_graph(gm)
return gm
def _match_rotary_subpattern_V1(add_node: Node) -> Optional[Dict[str, Node]]:
"""
Given an aten.add.Tensor node that is expected to compute:
output = (raw_input * unsqueeze(cos)) + (rotate_half(raw_input) * unsqueeze(sin))
where rotate_half is implemented as:
rotate_half(x) = cat([ -slice(x, second_half), slice(x, first_half) ], dim=-1)
this function inspects the structure of add_node and returns a dictionary with:
- "raw_input": the original q/k tensor,
- "unsqueeze_cos": the unsqueeze node feeding the raw multiplication,
- "unsqueeze_sin": the unsqueeze node feeding the rotated multiplication,
- "add_node": the addition node itself.
Returns None if the pattern does not match.
"""
# Check that add_node is an add operation with two inputs.
if not is_op(add_node, torch.ops.aten.add):
return None
if not (len(add_node.args) == 2):
return None
mul1, mul2 = add_node.args
# Both inputs to the add should be multiplications.
if not is_op(mul1, torch.ops.aten.mul):
return None
if not is_op(mul2, torch.ops.aten.mul):
return None
# One branch should be the raw branch and the other the rotated branch.
# We decide by checking whether one multiplication's first argument is a cat (i.e. the rotate_half result).
if is_op(mul1.args[0], torch.ops.aten.cat):
mul_rot = mul1
mul_raw = mul2
elif is_op(mul2.args[0], torch.ops.aten.cat):
mul_rot = mul2
mul_raw = mul1
else:
return None
# Verify that both multiplications have an unsqueeze as their second argument.
unsqueeze_cos = mul_raw.args[1]
unsqueeze_sin = mul_rot.args[1]
if not is_op(unsqueeze_cos, torch.ops.aten.unsqueeze):
return None
if not is_op(unsqueeze_sin, torch.ops.aten.unsqueeze):
return None
# Check that the rotated branch is a cat of two tensors along -1.
cat_node = mul_rot.args[0]
if not is_op(cat_node, torch.ops.aten.cat):
return None
# Expecting two inputs in a list/tuple.
cat_inputs = cat_node.args[0]
if not (isinstance(cat_inputs, (list, tuple)) and len(cat_inputs) == 2):
return None
# One of the two inputs should be a negation of a slice, the other should be a slice.
first_item, second_item = cat_inputs
if not is_op(first_item, torch.ops.aten.neg):
return None
if not is_op(second_item, torch.ops.aten.slice):
return None
# The negation node should wrap a slice.
neg_node = first_item
if not (len(neg_node.args) >= 1 and is_op(neg_node.args[0], torch.ops.aten.slice)):
return None
# For simplicity, require that the two slice operations (the one inside neg and the one used directly)
# are applied on the same original tensor. This original tensor is the one being rotated.
slice_in_neg = neg_node.args[0]
if slice_in_neg.args[0] != second_item.args[0]:
return None
# Finally, the raw branch should multiply the original tensor (i.e. q or k) by unsqueeze_cos.
raw_input = mul_raw.args[0]
# We also expect that the tensor being sliced (and negated) is the same as raw_input.
if raw_input != slice_in_neg.args[0]:
return None
return {
"raw_input": raw_input,
"unsqueeze_cos": unsqueeze_cos,
"unsqueeze_sin": unsqueeze_sin,
"add_node": add_node,
}
def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]]:
"""
Given a type_as node, this function inspects the graph
structure and returns a dictionary with:
- "input": the original xq (or xk) tensor,
- "inv_freq": the freqs_cis tensor (before unsqueeze),
- "out": the type_as node corresponding to the branch output.
Expected branch structure for each output:
x_out = type_as( flatten( view_as_real( view_as_complex(reshape(to_dtype(x))) * unsqueeze(freqs_cis) ) ) )
Returns None if the structure does not match.
"""
if not is_op(type_as_node, torch.ops.aten.type_as):
return None
# The type_as node should have at least one argument: its first argument is the flatten op.
if not (len(type_as_node.args) >= 1):
return None
flatten_node = type_as_node.args[0]
if not is_op(flatten_node, torch.ops.aten.flatten):
return None
# The input of the flatten op should be a view_as_real op.
if not (len(flatten_node.args) >= 1):
return None
view_as_real_node = flatten_node.args[0]
if not is_op(view_as_real_node, torch.ops.aten.view_as_real):
return None
# The input of view_as_real should be a multiplication.
if not (len(view_as_real_node.args) >= 1):
return None
mul_node = view_as_real_node.args[0]
if not is_op(mul_node, torch.ops.aten.mul):
return None
if len(mul_node.args) != 2:
return None
# In the multiplication, one operand should be an unsqueeze of freqs_cis and
# the other operand is the output of view_as_complex.
if is_op(mul_node.args[0], torch.ops.aten.unsqueeze):
unsqueeze_node = mul_node.args[0]
vc_node = mul_node.args[1]
elif is_op(mul_node.args[1], torch.ops.aten.unsqueeze):
unsqueeze_node = mul_node.args[1]
vc_node = mul_node.args[0]
else:
return None
# Verify that the unsqueeze is performed along dimension 2.
if not (len(unsqueeze_node.args) >= 2 and unsqueeze_node.args[1] == 2):
return None
inv_freq_candidate = unsqueeze_node.args[0]
# Match the view_as_complex branch.
if not is_op(vc_node, torch.ops.aten.view_as_complex):
return None
if not (len(vc_node.args) >= 1):
return None
reshape_node = vc_node.args[0]
if not is_op(reshape_node, torch.ops.aten.reshape):
return None
# The reshape op should get its input from a to(dtype) conversion.
if not (len(reshape_node.args) >= 1):
return None
to_node = reshape_node.args[0]
if not is_op(to_node, torch.ops.aten.to):
return None
if not (len(to_node.args) >= 1):
return None
input_tensor = to_node.args[0]
return {
"input": input_tensor,
"inv_freq": inv_freq_candidate,
"out": type_as_node,
}
def _process_rope_v1(
graph: GraphModule,
q_match: Dict[str, Node],
k_match: Dict[str, Node],
start_boundary: Node,
rope_flash_cache: DefaultDict[Any, Optional[Node]],
rope_position_ids_cache: Dict[str, Node],
) -> None:
"""
Process a region that matched the legacy RoPE pattern (v1).
Inserts the custom op (flashinfer) and replaces the original add nodes.
Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
and reuses those nodes when possible.
"""
q_node = q_match["raw_input"]
k_node = k_match["raw_input"]
cos_node = q_match["unsqueeze_cos"].args[0]
sin_node = q_match["unsqueeze_sin"].args[0]
# Sanity-check: ensure cos/sin nodes trace back to aten.cos/aten.sin.
bfs(
cos_node,
lambda n: is_op(n, torch.ops.aten.cos),
attr_next="all_input_nodes",
boundary=start_boundary,
)
bfs(
sin_node,
lambda n: is_op(n, torch.ops.aten.sin),
attr_next="all_input_nodes",
boundary=start_boundary,
)
# Infer input layout; default to [b, n, s, d] if inference fails.
q_fake = q_node.meta.get("val", None)
if q_fake is not None and len(q_fake.shape) > 2:
need_transpose = isinstance(q_fake.shape[1], int)
ad_logger.debug(
f"Inferred RoPE input layout: {'[b, n, s, d]' if need_transpose else '[b, s, n, d]'}"
)
# Additional sanity check: the inferred sequence dimension should be symbolic.
if need_transpose:
if not isinstance(q_fake.shape[2], torch.SymInt):
ad_logger.warning(
"Sanity check failed: q_fake.shape[2] should be symbolic. Defaulting to [b, n, s, d]"
)
need_transpose = True
else:
if not isinstance(q_fake.shape[1], torch.SymInt):
ad_logger.warning(
"Sanity check failed: q_fake.shape[2] should be symbolic. Defaulting to [b, n, s, d]"
)
need_transpose = True
else:
ad_logger.warning("Unable to infer layout of q node. Defaulting to [b, n, s, d].")
need_transpose = True
with graph.inserting_before(q_match["add_node"]):
if need_transpose:
q_for_op = graph.call_function(torch.ops.aten.transpose, args=(q_node, 1, 2))
k_for_op = graph.call_function(torch.ops.aten.transpose, args=(k_node, 1, 2))
q_for_op_contig = graph.call_method("contiguous", (q_for_op,))
k_for_op_contig = graph.call_method("contiguous", (k_for_op,))
else:
q_for_op_contig, k_for_op_contig = q_node, k_node
head_dim = cos_node.meta["val"].shape[-1]
half_head_dim = head_dim // 2
cache = rope_flash_cache
cache_key = (cos_node, sin_node)
if cache_key in cache:
fused_cos_sin = cache[cache_key]
else:
cos_prefix = graph.call_function(
torch.ops.aten.slice, args=(cos_node, -1, 0, half_head_dim)
)
sin_prefix = graph.call_function(
torch.ops.aten.slice, args=(sin_node, -1, 0, half_head_dim)
)
fused_cos_sin = graph.call_function(
torch.ops.aten.cat, args=((cos_prefix, sin_prefix), -1)
)
fused_cos_sin = graph.call_function(operator.getitem, args=(fused_cos_sin, 0))
fused_cos_sin = graph.call_method("to", (fused_cos_sin, torch.float32))
cache[cache_key] = fused_cos_sin
position_ids = _get_position_ids(
graph,
q_for_op_contig,
batch_dim=0,
seq_dim=1,
rope_position_ids_cache=rope_position_ids_cache,
)
flash_node = graph.call_function(
torch.ops.rope.flashinfer,
args=(q_for_op_contig, k_for_op_contig, position_ids, fused_cos_sin, True),
)
with graph.inserting_after(flash_node):
raw_q = graph.call_function(operator.getitem, args=(flash_node, 0))
raw_k = graph.call_function(operator.getitem, args=(flash_node, 1))
if need_transpose:
with graph.inserting_after(raw_q):
new_q = graph.call_function(torch.ops.aten.transpose, args=(raw_q, 1, 2))
with graph.inserting_after(raw_k):
new_k = graph.call_function(torch.ops.aten.transpose, args=(raw_k, 1, 2))
else:
new_q, new_k = raw_q, raw_k
new_q.meta["val"] = q_match["add_node"].meta.get("val", None)
new_k.meta["val"] = k_match["add_node"].meta.get("val", None)
q_match["add_node"].replace_all_uses_with(new_q)
k_match["add_node"].replace_all_uses_with(new_k)
def _process_rope_v2(
graph: GraphModule,
q_match: Dict[str, Node],
k_match: Dict[str, Node],
rope_flash_cache: DefaultDict[Any, Optional[Node]],
rope_position_ids_cache: Dict[str, Node],
) -> None:
"""
Process a region that matched the complex-multiplication RoPE pattern (v2).
Inserts the custom op (flashinfer) after extracting frequency information,
and replaces the original type_as nodes.
Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
and reuses those nodes when possible.
"""
q_node = q_match["input"]
k_node = k_match["input"]
inv_freq_node = q_match["inv_freq"]
if inv_freq_node != k_match["inv_freq"]:
raise RuntimeError("Mismatch of freqs_cis (inv_freq) between branches.")
# Sanity check that input layout is BSND (no transpose needed).
q_fake = q_node.meta.get("val", None)
if q_fake is not None and len(q_fake.shape) > 2:
if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)):
ad_logger.warning(
f"""Sanity check failed: q_fake should have shape [b, s, n, d],
s should be symbolic and n should be int, instead got shape {q_fake.shape}"""
)
else:
ad_logger.warning(
"Sanity check failed: q_fake should be 3D or 4D, but got "
f"{q_fake.shape if q_fake is not None else 'no fake tensor metadata'}"
)
# Retrieve or register the lookup table for inv_freq_node -> cos_sin_flash
cache = rope_flash_cache
if inv_freq_node in cache:
cos_sin_flash = cache[inv_freq_node]
else:
# Compute the fused cosine/sine cache.
with graph.inserting_after(inv_freq_node):
real_part = graph.call_function(torch.ops.aten.real, args=(inv_freq_node,))
imag_part = graph.call_function(torch.ops.aten.imag, args=(inv_freq_node,))
with graph.inserting_after(real_part):
cos_sin_flash_3d = graph.call_function(
torch.ops.aten.cat, args=((real_part, imag_part), -1)
)
with graph.inserting_after(cos_sin_flash_3d):
cos_sin_flash = graph.call_function(operator.getitem, args=(cos_sin_flash_3d, 0))
with graph.inserting_after(cos_sin_flash):
cos_sin_flash = graph.call_method("to", (cos_sin_flash, torch.float32))
cache[inv_freq_node] = cos_sin_flash
with graph.inserting_before(q_match["out"]):
position_ids = _get_position_ids(
graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=rope_position_ids_cache
)
flash_node = graph.call_function(
torch.ops.rope.flashinfer,
args=(q_node, k_node, position_ids, cos_sin_flash, False),
)
with graph.inserting_after(flash_node):
raw_q = graph.call_function(operator.getitem, args=(flash_node, 0))
raw_k = graph.call_function(operator.getitem, args=(flash_node, 1))
raw_q.meta["val"] = q_match["out"].meta.get("val", None)
raw_k.meta["val"] = k_match["out"].meta.get("val", None)
q_match["out"].replace_all_uses_with(raw_q)
k_match["out"].replace_all_uses_with(raw_k)
def _get_position_ids(
graph: GraphModule,
q_node: Node,
batch_dim: int = 0,
seq_dim: int = 1,
rope_position_ids_cache: Dict[str, Node] = None,
) -> Node:
"""
Retrieves the cached position_ids from the graph if available, or computes and caches them.
It uses the symbolic batch and sequence sizes from q_node with the provided dimension indices.
"""
if rope_position_ids_cache is None:
rope_position_ids_cache = {}
if "position_ids" in rope_position_ids_cache:
return rope_position_ids_cache["position_ids"]
sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, batch_dim))
sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, seq_dim))
# Retrieve device information, ensuring it is a torch.device.
device = q_node.meta.get("device", "cpu")
if isinstance(device, str):
device = torch.device(device)
# Build positions: arange(sym_seq) -> view -> expand -> flatten.
positions_node = graph.call_function(
torch.ops.aten.arange,
args=(sym_seq,),
kwargs={"dtype": torch.float32, "device": device, "pin_memory": False},
)
positions_node = graph.call_function(torch.ops.aten.view, args=(positions_node, (1, -1)))
positions_node = graph.call_function(
torch.ops.aten.expand, args=(positions_node, (sym_batch, -1))
)
position_ids = graph.call_function(torch.ops.aten.flatten, args=(positions_node,))
rope_position_ids_cache["position_ids"] = position_ids
return position_ids
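
As a rough, non-runnable sketch of how these matchers are meant to be wired up: `model` and `x` below are placeholders for any exported module that applies the HF-style cos/sin RoPE (for example the RotaryModel in the unit tests at the end of this commit), and plain torch.export stands in for the repo's torch_export_to_gm helper used by those tests.

import torch

from tensorrt_llm._torch.auto_deploy.transformations.library.rope import match_rope_v1
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

# Export to an fx GraphModule (placeholder; the tests use torch_export_to_gm with dynamic shapes).
gm = torch.export.export(model, (x,)).module()
# Rewrite matched cos/sin RoPE branches into calls to torch.ops.rope.flashinfer.
gm = match_rope_v1(gm)
assert any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes)

match_rope_v2 is applied in exactly the same way for the complex-multiplication (Llama4-style) variant.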

View File

@@ -1,7 +1,7 @@
"""Common utils for torch fx graph transformation."""
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
from typing import Callable, Iterable, List, Optional, Tuple, Union
import torch
from torch._ops import OpOverload, OpOverloadPacket
@@ -287,3 +287,23 @@ def identify_regions_between_residuals(gm: GraphModule) -> List[Node]:
boundary_nodes.append(output_node)
return boundary_nodes
def bfs(
node: Node, target: Callable, attr_next: str = "users", boundary: Optional[Node] = None
) -> Node:
queue = [node]
visited = set()
while queue:
cur_node = queue.pop(0)
if boundary is not None and cur_node == boundary:
continue # Skip the boundary node.
if target(cur_node):
return cur_node
for next_node in getattr(cur_node, attr_next):
if boundary is not None and next_node == boundary:
continue # Do not expand past the boundary.
if next_node not in visited:
visited.add(next_node)
queue.append(next_node)
raise RuntimeError(f"Could not find node with target condition {target}.")

View File

@@ -1,5 +1,5 @@
import copy
from typing import Callable, List, Optional
from typing import Callable, Dict, List, Optional
import numpy as np
import torch
@@ -34,6 +34,7 @@ def run_test(
rtol: float = 1e-3,
test_load_hook: bool = True,
strict_loading: bool = True,
dynamic_shapes: Dict = None,
*args, # Additional arguments for transform
) -> GraphModule:
# run model once
@@ -44,7 +45,7 @@
print(num_params_model)
# export + check (we clone the state dict to have a bit more freedom in testing below)
gm = torch_export_to_gm(model, args=(x,), clone=True)
gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
print(gm)
y_gm = gm(x)
num_params_gm = count_parameters(gm)

View File

@@ -0,0 +1,145 @@
from typing import Tuple
import flashinfer
import pytest
import torch
import tensorrt_llm._torch.auto_deploy # noqa: F401
torch.manual_seed(0)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
@pytest.mark.parametrize("head_dim", [64, 256]) # head_dim must be a multiple of 64
@pytest.mark.parametrize(
"dtype,atol,rtol",
[
(torch.bfloat16, 1e-4, 1e-4),
(torch.float16, 5e-4, 5e-4),
],
ids=["bfloat16", "float16"], # q/k must be in half precision
)
def test_flashinfer_and_custom_rope_ops(dtype, atol, rtol, head_dim):
device = "cuda"
batch = 2
seq_len = 4
n_head = 3
# Prepare rotary embedding values.
inv_freq = 1.0 / (
10000
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
)
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # [seq_len, head_dim//2]
cos_vals = torch.cos(angles) # [seq_len, head_dim//2]
sin_vals = torch.sin(angles) # [seq_len, head_dim//2]
# For the direct FlashInfer call and the custom op: non-interleaved cache [seq_len, head_dim] (cos and sin halves concatenated).
cos_sin_cache = torch.cat([cos_vals, sin_vals], dim=1)
# For the HF reference implementation: duplicated cos/sin layout [seq_len, head_dim].
cos_new = torch.cat([cos_vals, cos_vals], dim=-1)
sin_new = torch.cat([sin_vals, sin_vals], dim=-1)
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
# Direct FlashInfer kernel call.
query_flat = query.view(batch * seq_len, n_head * head_dim)
key_flat = key.view(batch * seq_len, n_head * head_dim)
positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)])
q_flash, k_flash = flashinfer.rope.apply_rope_with_cos_sin_cache(
positions, query_flat, key_flat, head_dim, cos_sin_cache, is_neox=True
)
q_flash = q_flash.view(batch, seq_len, n_head, head_dim)
k_flash = k_flash.view(batch, seq_len, n_head, head_dim)
# HF implementation using apply_rotary_pos_emb.
# HF expects [batch, n_head, seq_len, head_dim] for unsqueeze_dim=1
q_for_hf = query.transpose(1, 2).clone()
k_for_hf = key.transpose(1, 2).clone()
cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim]
q_hf, k_hf = apply_rotary_pos_emb(q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1)
# Convert outputs to [batch, seq_len, n_head, head_dim]
q_hf = q_hf.transpose(1, 2).to(dtype)
k_hf = k_hf.transpose(1, 2).to(dtype)
# Custom op call
custom_q, custom_k = torch.ops.rope.flashinfer(query, key, positions, cos_sin_cache, True)
torch.testing.assert_close(q_hf, q_flash, rtol=rtol, atol=atol)
torch.testing.assert_close(k_hf, k_flash, rtol=rtol, atol=atol)
torch.testing.assert_close(q_hf, custom_q, rtol=rtol, atol=atol)
torch.testing.assert_close(k_hf, custom_k, rtol=rtol, atol=atol)
# Version 2: complex multiplication approach
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2)
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
@pytest.mark.parametrize("head_dim", [64, 256]) # Must be a multiple of 64
@pytest.mark.parametrize(
"dtype,atol,rtol",
[
(torch.bfloat16, 1e-5, 1e-5),
(torch.float16, 5e-4, 5e-4),
],
ids=["bfloat16", "float16"], # q/k must be in half precision
)
def test_flashinfer_complex_rotary(dtype, atol, rtol, head_dim):
device = "cuda"
batch = 2
seq_len = 4
n_head = 3
inv_freq = 1.0 / (
10000
** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
)
positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # shape: (seq_len, head_dim//2)
freqs_cis = torch.polar(torch.ones((seq_len, head_dim // 2), device=device), angles)
freqs_cis = freqs_cis.unsqueeze(0).expand(batch, -1, -1) # shape: (B, seq, head_dim//2)
query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
out_q_v2, out_k_v2 = apply_rotary_emb(query, key, freqs_cis)
cos_from_freqs = torch.real(freqs_cis) # (B, seq, head_dim//2)
sin_from_freqs = torch.imag(freqs_cis) # (B, seq, head_dim//2)
cos_sin_cache = torch.cat([cos_from_freqs, sin_from_freqs], dim=-1)[0]  # (seq, head_dim)
# Llama4-style RoPE treats q/k as interleaved pairs, hence is_neox=False in the call below.
positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)])
custom_q, custom_k = torch.ops.rope.flashinfer(query, key, positions, cos_sin_cache, False)
torch.testing.assert_close(out_q_v2, custom_q, rtol=rtol, atol=atol)
torch.testing.assert_close(out_k_v2, custom_k, rtol=rtol, atol=atol)

View File

@@ -0,0 +1,211 @@
import pytest
import torch
from _graph_test_helpers import run_test
from torch.export import Dim
from tensorrt_llm._torch.auto_deploy.transformations.library.rope import (
match_rope_v1,
match_rope_v2,
)
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
torch.manual_seed(0)
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def _precompute_freqs_cis(seq_len: int, head_dim: int, rope_theta: float):
dtype = torch.float32
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
positions = torch.arange(seq_len, dtype=torch.float32)
freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(dtype)
sin = emb.sin().to(dtype)
return cos, sin
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2) and complex dtype.
):
# Reshape the inputs to pair the last dimension.
xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# Multiply with frequencies. Note that freqs_cis is expected to broadcast with an extra head dim.
xq_out = torch.view_as_real(xq_complex * freqs_cis[:, :, None, :]).flatten(3)
xk_out = torch.view_as_real(xk_complex * freqs_cis[:, :, None, :]).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def _precompute_freqs_cis_v2(seq_len: int, head_dim: int, rope_theta: float):
"""
Compute the frequency tensor for the complex multiplication RoPE variant.
Returns a complex tensor of shape (seq_len, head_dim//2).
"""
inv_freq = 1.0 / (
rope_theta ** (torch.arange(0, head_dim // 2, dtype=torch.float32) / (head_dim // 2))
)
positions = torch.arange(seq_len, dtype=torch.float32)
angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # (seq_len, head_dim//2)
# Create a complex tensor from magnitude=1 and the computed angles.
freqs_cis = torch.polar(torch.ones_like(angles), angles)
return freqs_cis
class RotaryModel(torch.nn.Module):
def __init__(
self,
hidden_size: int,
max_seq_len: int,
num_heads: int,
num_kv_heads: int,
layout: str = "BNSD",
):
super().__init__()
self.hidden_size = hidden_size
self.max_seq_len = max_seq_len
self.layout = layout
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = hidden_size // num_heads
self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim)
self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
q = self.linear_q(x)
k = self.linear_k(x)
batch, seq, _ = q.shape
q = q.view(batch, seq, self.num_heads, self.head_dim)
k = k.view(batch, seq, self.num_kv_heads, self.head_dim)
if self.layout == "BNSD":
q = q.permute(0, 2, 1, 3).contiguous() # [B, N, S, D]
k = k.permute(0, 2, 1, 3).contiguous()
unsqueeze_dim = 1
else: # BSND
unsqueeze_dim = 2
cos, sin = _precompute_freqs_cis(seq, self.head_dim, rope_theta=10000)
cos = cos.to(q.device).unsqueeze(0).expand(batch, -1, -1)
sin = sin.to(q.device).unsqueeze(0).expand(batch, -1, -1)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim)
if self.layout == "BNSD":
# [B, N, S, D] -> [B, S, N*D]
q_embed = q_embed.permute(0, 2, 1, 3).reshape(batch, seq, -1)
k_embed = k_embed.permute(0, 2, 1, 3).reshape(batch, seq, -1)
else: # BSND
q_embed = q_embed.reshape(batch, seq, -1)
k_embed = k_embed.reshape(batch, seq, -1)
output = torch.cat([q_embed, k_embed], dim=-1)
return output.to(torch.float16)
def get_dynamic_shapes(self):
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}
class RotaryModelV2(torch.nn.Module):
def __init__(self, hidden_size: int, max_seq_len: int, num_heads: int, num_kv_heads: int):
super().__init__()
self.hidden_size = hidden_size
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = hidden_size // num_heads
self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim)
self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch, seq, _ = x.shape
q = self.linear_q(x).view(batch, seq, self.num_heads, self.head_dim)
k = self.linear_k(x).view(batch, seq, self.num_kv_heads, self.head_dim)
freqs_cis = _precompute_freqs_cis_v2(seq, self.head_dim, rope_theta=10000)
freqs_cis = freqs_cis.to(x.device).unsqueeze(0).expand(batch, -1, -1)
q_embed, k_embed = apply_rotary_emb(q, k, freqs_cis)
q_embed = q_embed.reshape(batch, seq, -1)
k_embed = k_embed.reshape(batch, seq, -1)
output = torch.cat([q_embed, k_embed], dim=-1)
return output.to(torch.float16)
def get_dynamic_shapes(self):
return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}
@pytest.mark.parametrize("layout", ["BNSD", "BSND"])
@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4)])
@torch.inference_mode()
def test_match_rope(layout, num_heads, num_kv_heads):
batch_size, seq_len = 8, 16
hidden_size = 512
max_position_embeddings = seq_len
model = RotaryModel(
hidden_size, max_position_embeddings, num_heads, num_kv_heads, layout=layout
).to("cuda", dtype=torch.float16)
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
dynamic_shapes = model.get_dynamic_shapes()
_ = run_test(
model,
x,
match_rope_v1,
lambda gm: any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes),
lambda num_p_og: num_p_og,
atol=1e-3,
rtol=1e-3,
test_load_hook=True,
strict_loading=True,
dynamic_shapes=dynamic_shapes,
)
@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4)])
@torch.inference_mode()
def test_match_rope_v2(num_heads, num_kv_heads):
batch_size, seq_len = 8, 16
hidden_size = 512
max_position_embeddings = seq_len
model = RotaryModelV2(hidden_size, max_position_embeddings, num_heads, num_kv_heads).to(
"cuda", dtype=torch.float16
)
x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
dynamic_shapes = model.get_dynamic_shapes()
_ = run_test(
model,
x,
match_rope_v2,
lambda gm: any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes),
lambda num_p_og: num_p_og,
atol=1e-3,
rtol=1e-3,
test_load_hook=True,
strict_loading=True,
dynamic_shapes=dynamic_shapes,
)