mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
feat:[AutoDeploy] Enhance RoPE support (#3115)
* add test to map flashinfer rope op with triton custom rope ops and pytorch rope in fused_mha
* add rope matcher and unit tests
* capture cos and sin from graph
* revert fuse_mha op change
* minor update to address comment and remove redundant unit test
* move view and transpose into graph nodes and update unit test to test custom op directly
* move view into custom op, update bfs with bound, update custom op return type to be half precision
* custom op update to support 3D input
* handle bnsd and bsnd format, update tests, handle 3D cos/sin input to the custom op
* add llama4 rope test, update custom op with is_neox flag
* add llama4 style rope to matcher and update unit test
* separate into two transformations
* fix when num_head != num_kv_head; add support for cached position_ids and cos_sin_cache in graph; update unit tests
* minor update, cache locally and propagate meta info of qk nodes
* minor: fix cos_sin_cache not float
* minor: move cache into matcher

Signed-off-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
This commit is contained in:
parent 11b0091863
commit ec723fa993
@@ -2,6 +2,7 @@
from .dist import *
from .flashinfer_attention import *
from .flashinfer_rope import *
from .fused_moe import *
from .linear import *
from .mla import *
@@ -0,0 +1,65 @@
from typing import Tuple

import flashinfer
import torch


@torch.library.custom_op("rope::flashinfer", mutates_args=())
def apply_rope_with_input_pos_flashinfer(
    q: torch.Tensor,
    k: torch.Tensor,
    position_ids: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Applies rotary positional embeddings (RoPE) to query and key tensors using the FlashInfer kernel.
    This updated version expects precomputed positional IDs and a fused cosine-sine cache.

    Inputs:
    - q, k (torch.Tensor):
        Tensors of shape [batch, seq_len, n_head, head_dim] (or a 3D variant)
        in half precision. Note: head_dim must be a multiple of 64.
    - position_ids (torch.Tensor):
        Precomputed tensor of positional indices; it is shared across calls in the graph.
    - cos_sin_cache (torch.Tensor):
        Precomputed fused tensor created by concatenating the first half of the cosine and sine
        components derived from the inv_freq.
    - is_neox (bool):
        Flag to indicate whether to invoke the FlashInfer kernel in Neox mode.

    Returns:
        A tuple of:
        - Rotated query tensor of the same shape and half precision as input.
        - Rotated key tensor of the same shape and half precision as input.
    """
    q_shape = q.shape
    k_shape = k.shape
    batch_size, seq_len = q_shape[:2]

    head_dim = cos_sin_cache.shape[-1]

    q_flat = q.view(batch_size * seq_len, -1)
    k_flat = k.view(batch_size * seq_len, -1)

    position_ids = position_ids.to(q.device)

    print("cos_sin_cache.shape", cos_sin_cache.shape)

    query_rotated_flash, key_rotated_flash = flashinfer.rope.apply_rope_with_cos_sin_cache(
        position_ids, q_flat, k_flat, head_dim, cos_sin_cache, is_neox=is_neox
    )
    query_rotated_flash = query_rotated_flash.view(q_shape)
    key_rotated_flash = key_rotated_flash.view(k_shape)
    return query_rotated_flash, key_rotated_flash


@apply_rope_with_input_pos_flashinfer.register_fake
def apply_rope_with_input_pos_flashinfer_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    position_ids: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return q, k
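
As a usage sketch (mirroring the unit test near the end of this diff; sizes are illustrative), the custom op can be driven directly once the package has been imported to register it:

import torch

import tensorrt_llm._torch.auto_deploy  # noqa: F401  # registers torch.ops.rope.flashinfer

device = "cuda"
batch, seq_len, n_head, head_dim = 2, 4, 3, 64  # head_dim must be a multiple of 64

# Fused cache: first-half cosines and first-half sines concatenated along the last dim.
inv_freq = 1.0 / (
    10000 ** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
)
angles = torch.arange(seq_len, dtype=torch.float32, device=device)[:, None] * inv_freq[None, :]
cos_sin_cache = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)  # [seq_len, head_dim]

# Flat position ids, one entry per token in the flattened [batch * seq_len] layout.
position_ids = torch.arange(seq_len, device=device).repeat(batch)

q = torch.randn(batch, seq_len, n_head, head_dim, dtype=torch.float16, device=device)
k = torch.randn(batch, seq_len, n_head, head_dim, dtype=torch.float16, device=device)
q_rot, k_rot = torch.ops.rope.flashinfer(q, k, position_ids, cos_sin_cache, True)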
@@ -20,7 +20,6 @@ def apply_rotary_pos_emb(
):
    """
    Apply rotary positional embeddings to query and key tensors.

    Args:
        q: Query tensor of shape [batch, n_heads, seq_len, head_dim]
        k: Key tensor of shape [batch, n_kv_heads, seq_len, head_dim]
@@ -28,7 +27,6 @@ def apply_rotary_pos_emb(
        head_dim: Dimension of each head
        rope_theta: Base value for RoPE (default 10000.0)
        rope_scale: Scaling factor for positions (default 1.0)

    Returns:
        Tuple of transformed query and key tensors
    """
@@ -108,7 +106,6 @@ def fused_mha(
    rope_scale: Optional[float] = None,
) -> torch.Tensor:
    """Fused MHA+Rope that takes raw input from q, k, v GEMMs.

    Rope is performed according to the specified rope configuration. No support for caching.
    """
    # b, s info
@@ -7,6 +7,7 @@ from .fused_moe import *
from .fusion import *
from .kvcache import *
from .quantization import *
from .rope import *
from .sharding import *

try:
@@ -1,7 +1,7 @@
"""Pattern matcher to detect MHA pattern and replace with simple fused_mha op."""

from collections import defaultdict
from typing import Callable, Dict, List, Optional
from typing import Dict, List, Optional

import torch
from torch._subclasses import FakeTensor
@@ -9,7 +9,7 @@ from torch.fx import Graph, GraphModule, Node

from ...models.factory import PositionalEmbeddingConfig
from ...utils.logger import ad_logger
from ...utils.node_utils import is_dist_op, is_linear_op, is_op
from ...utils.node_utils import bfs, is_dist_op, is_linear_op, is_op
from .._graph import canonicalize_graph


@@ -20,16 +20,6 @@ def _is_dist_lin_op(node: Node, exclude: Optional[List[Node]] = None) -> bool:
    )


def _bfs(node: Node, target: Callable, attr_next: str = "users") -> Node:
    queue = [node]
    while queue:
        cur_node = queue.pop(0)
        if target(cur_node):
            return cur_node
        queue.extend(getattr(cur_node, attr_next))
    raise RuntimeError(f"Could not find node with target condition {target}.")


def identify_and_fuse_mha(
    egm: GraphModule, pos_embd_config: PositionalEmbeddingConfig
) -> GraphModule:
@@ -66,9 +56,9 @@ def identify_and_fuse_mha(
        # from the sdpa node, identify q, k, v, and out GEMMs via BFS
        for arg in mha_node.args[:3]:
            mha_gemms[mha_node].append(
                _bfs(arg, lambda n: _is_dist_lin_op(n, mha_gemms[mha_node]), "all_input_nodes")
                bfs(arg, lambda n: _is_dist_lin_op(n, mha_gemms[mha_node]), "all_input_nodes")
            )
        mha_gemms[mha_node].append(_bfs(mha_node, _is_dist_lin_op, "users"))
        mha_gemms[mha_node].append(bfs(mha_node, _is_dist_lin_op, "users"))

        # get fake q tensor that is an MHA input node to retrieve head_dim
        q_fake: FakeTensor = mha_node.args[0].meta["val"]
tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py (new file, 508 lines)
@@ -0,0 +1,508 @@
import operator
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Optional

import torch
from torch.fx import GraphModule, Node

from ...utils.logger import ad_logger
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_op
from .._graph import canonicalize_graph


def match_rope_v1(gm: GraphModule) -> GraphModule:
    """
    Identify and replace legacy RoPE subgraphs (explicit cos/sin multiplication pattern):

        output = (raw * unsqueeze(cos)) + (rotate_half(raw) * unsqueeze(sin))

    If exactly two such branches (query and key) are detected within each region, they're replaced
    by a call to `torch.ops.rope.flashinfer`.
    """
    graph = gm.graph
    boundary_nodes: List[torch.fx.Node] = identify_regions_between_residuals(gm)

    rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None)
    rope_position_ids_cache: Dict[str, Node] = {}

    for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
        matches = []
        node = start_boundary
        while node != end_boundary:
            if is_op(node, torch.ops.aten.add):
                match_info = _match_rotary_subpattern_V1(node)
                if match_info:
                    matches.append(match_info)
            node = node.next

        if not matches:
            continue
        if len(matches) != 2:
            raise RuntimeError(
                f"Expected exactly 2 legacy RoPE branches between {start_boundary} and {end_boundary}, "
                f"found {len(matches)}."
            )

        # Assume the first matched branch is query (q), second is key (k).
        # This assumption is based on the default ordering in the exported graph,
        # since node naming conventions don't reliably indicate q/k branches.
        q_match, k_match = matches
        _process_rope_v1(
            graph,
            q_match,
            k_match,
            start_boundary,
            rope_flash_cache,
            rope_position_ids_cache,
        )

    gm = canonicalize_graph(gm)
    return gm


def match_rope_v2(gm: GraphModule) -> GraphModule:
    """
    Identify and replace RoPE subgraphs using complex multiplication pattern:

        output = type_as(flatten(view_as_real(mul(view_as_complex(reshape(to_dtype(x))), unsqueeze(freqs_cis, 2)))), x)

    If exactly two such branches (query and key) are detected within each region, they're replaced
    by a call to `torch.ops.rope.flashinfer`.
    """
    graph = gm.graph
    boundary_nodes: List[torch.fx.Node] = identify_regions_between_residuals(gm)

    rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None)
    rope_position_ids_cache: Dict[str, Node] = {}

    for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
        matches = []
        node = start_boundary
        while node != end_boundary:
            if is_op(node, torch.ops.aten.type_as):
                match_info = _match_rotary_subpattern_V2(node)
                if match_info:
                    matches.append(match_info)
            node = node.next

        if not matches:
            continue
        if len(matches) != 2:
            raise RuntimeError(
                f"Expected exactly 2 complex RoPE branches between {start_boundary} and {end_boundary}, "
                f"found {len(matches)}."
            )

        # Assume the first matched branch is query (q), second is key (k).
        # This assumption is based on the default ordering in the exported graph,
        # since node naming conventions don't reliably indicate q/k branches.
        q_match, k_match = matches
        _process_rope_v2(
            graph,
            q_match,
            k_match,
            rope_flash_cache,
            rope_position_ids_cache,
        )

    gm = canonicalize_graph(gm)
    return gm


def _match_rotary_subpattern_V1(add_node: Node) -> Optional[Dict[str, Node]]:
    """
    Given an aten.add.Tensor node that is expected to compute:
        output = (raw_input * unsqueeze(cos)) + (rotate_half(raw_input) * unsqueeze(sin))
    where rotate_half is implemented as:
        rotate_half(x) = cat([ -slice(x, second_half), slice(x, first_half) ], dim=-1)
    this function inspects the structure of add_node and returns a dictionary with:
        - "raw_input": the original q/k tensor,
        - "unsqueeze_cos": the unsqueeze node feeding the raw multiplication,
        - "unsqueeze_sin": the unsqueeze node feeding the rotated multiplication,
        - "add_node": the addition node itself.
    Returns None if the pattern does not match.
    """
    # Check that add_node is an add operation with two inputs.
    if not is_op(add_node, torch.ops.aten.add):
        return None
    if not (len(add_node.args) == 2):
        return None

    mul1, mul2 = add_node.args
    # Both inputs to the add should be multiplications.
    if not is_op(mul1, torch.ops.aten.mul):
        return None
    if not is_op(mul2, torch.ops.aten.mul):
        return None

    # One branch should be the raw branch and the other the rotated branch.
    # We decide by checking if one multiplication's first argument is a cat (i.e. the rotate_half result).
    if is_op(mul1.args[0], torch.ops.aten.cat):
        mul_rot = mul1
        mul_raw = mul2
    elif is_op(mul2.args[0], torch.ops.aten.cat):
        mul_rot = mul2
        mul_raw = mul1
    else:
        return None

    # Verify that both multiplications have an unsqueeze as their second argument.
    unsqueeze_cos = mul_raw.args[1]
    unsqueeze_sin = mul_rot.args[1]
    if not is_op(unsqueeze_cos, torch.ops.aten.unsqueeze):
        return None
    if not is_op(unsqueeze_sin, torch.ops.aten.unsqueeze):
        return None

    # Check that the rotated branch is a cat of two tensors along -1.
    cat_node = mul_rot.args[0]
    if not is_op(cat_node, torch.ops.aten.cat):
        return None
    # Expecting two inputs in a list/tuple.
    cat_inputs = cat_node.args[0]
    if not (isinstance(cat_inputs, (list, tuple)) and len(cat_inputs) == 2):
        return None

    # One of the two inputs should be a negation of a slice, the other should be a slice.
    first_item, second_item = cat_inputs
    if not is_op(first_item, torch.ops.aten.neg):
        return None
    if not is_op(second_item, torch.ops.aten.slice):
        return None

    # The negation node should wrap a slice.
    neg_node = first_item
    if not (len(neg_node.args) >= 1 and is_op(neg_node.args[0], torch.ops.aten.slice)):
        return None

    # For simplicity, require that the two slice operations (the one inside neg and the one used directly)
    # are applied on the same original tensor. This original tensor is the one being rotated.
    slice_in_neg = neg_node.args[0]
    if slice_in_neg.args[0] != second_item.args[0]:
        return None

    # Finally, the raw branch should multiply the original tensor (i.e. q or k) by unsqueeze_cos.
    raw_input = mul_raw.args[0]
    # We also expect that the tensor being sliced (and negated) is the same as raw_input.
    if raw_input != slice_in_neg.args[0]:
        return None

    return {
        "raw_input": raw_input,
        "unsqueeze_cos": unsqueeze_cos,
        "unsqueeze_sin": unsqueeze_sin,
        "add_node": add_node,
    }
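
For orientation, the aten pattern above is what the usual eager-mode rotation traces to. A minimal sketch with illustrative shapes (the same rotate_half helper appears in the unit tests later in this diff):

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Traces to the neg/slice/slice/cat chain that the matcher checks above.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


q = torch.randn(2, 8, 4, 64)  # [b, n, s, d]
cos = torch.randn(2, 4, 64)   # [b, s, d]
sin = torch.randn(2, 4, 64)
# raw branch * unsqueeze(cos) plus rotated branch * unsqueeze(sin): the aten.add this matcher keys on
q_embed = (q * cos.unsqueeze(1)) + (rotate_half(q) * sin.unsqueeze(1))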
def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]]:
    """
    Given a type_as node, this function inspects the graph
    structure and returns a dictionary with:
        - "input": the original xq (or xk) tensor,
        - "inv_freq": the freqs_cis tensor (before unsqueeze),
        - "out": the type_as node corresponding to the branch output.

    Expected branch structure for each output:
        x_out = type_as( flatten( view_as_real( view_as_complex(reshape(to_dtype(x))) * unsqueeze(freqs_cis) ) ) )

    Returns None if the structure does not match.
    """
    if not is_op(type_as_node, torch.ops.aten.type_as):
        return None

    # The type_as node should have at least one argument: its first argument is the flatten op.
    if not (len(type_as_node.args) >= 1):
        return None
    flatten_node = type_as_node.args[0]
    if not is_op(flatten_node, torch.ops.aten.flatten):
        return None

    # The input of the flatten op should be a view_as_real op.
    if not (len(flatten_node.args) >= 1):
        return None
    view_as_real_node = flatten_node.args[0]
    if not is_op(view_as_real_node, torch.ops.aten.view_as_real):
        return None

    # The input of view_as_real should be a multiplication.
    if not (len(view_as_real_node.args) >= 1):
        return None
    mul_node = view_as_real_node.args[0]
    if not is_op(mul_node, torch.ops.aten.mul):
        return None
    if len(mul_node.args) != 2:
        return None

    # In the multiplication, one operand should be an unsqueeze of freqs_cis and
    # the other operand is the output of view_as_complex.
    if is_op(mul_node.args[0], torch.ops.aten.unsqueeze):
        unsqueeze_node = mul_node.args[0]
        vc_node = mul_node.args[1]
    elif is_op(mul_node.args[1], torch.ops.aten.unsqueeze):
        unsqueeze_node = mul_node.args[1]
        vc_node = mul_node.args[0]
    else:
        return None

    # Verify that the unsqueeze is performed along dimension 2.
    if not (len(unsqueeze_node.args) >= 2 and unsqueeze_node.args[1] == 2):
        return None
    inv_freq_candidate = unsqueeze_node.args[0]

    # Match the view_as_complex branch.
    if not is_op(vc_node, torch.ops.aten.view_as_complex):
        return None
    if not (len(vc_node.args) >= 1):
        return None
    reshape_node = vc_node.args[0]
    if not is_op(reshape_node, torch.ops.aten.reshape):
        return None

    # The reshape op should get its input from a to(dtype) conversion.
    if not (len(reshape_node.args) >= 1):
        return None
    to_node = reshape_node.args[0]
    if not is_op(to_node, torch.ops.aten.to):
        return None
    if not (len(to_node.args) >= 1):
        return None
    input_tensor = to_node.args[0]

    return {
        "input": input_tensor,
        "inv_freq": inv_freq_candidate,
        "out": type_as_node,
    }
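
The branch checked here corresponds to the complex-multiplication rotation used by Llama-style models. A minimal eager sketch of a single branch (shapes illustrative):

import torch

xq = torch.randn(2, 4, 3, 64, dtype=torch.float16)                    # [b, s, n, d]
freqs_cis = torch.polar(torch.ones(2, 4, 32), torch.randn(2, 4, 32))  # [b, s, d//2], complex

# to(float) -> reshape -> view_as_complex -> mul(unsqueeze(freqs_cis, 2)) -> view_as_real -> flatten -> type_as
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xq_out = torch.view_as_real(xq_ * freqs_cis.unsqueeze(2)).flatten(3).type_as(xq)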
def _process_rope_v1(
    graph: GraphModule,
    q_match: Dict[str, Node],
    k_match: Dict[str, Node],
    start_boundary: Node,
    rope_flash_cache: DefaultDict[Any, Optional[Node]],
    rope_position_ids_cache: Dict[str, Node],
) -> None:
    """
    Process a region that matched the legacy RoPE pattern (v1).
    Inserts the custom op (flashinfer) and replaces the original add nodes.
    Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
    and reuses those nodes when possible.
    """
    q_node = q_match["raw_input"]
    k_node = k_match["raw_input"]
    cos_node = q_match["unsqueeze_cos"].args[0]
    sin_node = q_match["unsqueeze_sin"].args[0]

    # Sanity-check: ensure cos/sin nodes trace back to aten.cos/aten.sin.
    bfs(
        cos_node,
        lambda n: is_op(n, torch.ops.aten.cos),
        attr_next="all_input_nodes",
        boundary=start_boundary,
    )
    bfs(
        sin_node,
        lambda n: is_op(n, torch.ops.aten.sin),
        attr_next="all_input_nodes",
        boundary=start_boundary,
    )

    # Infer input layout; default to [b, n, s, d] if inference fails.
    q_fake = q_node.meta.get("val", None)
    if q_fake is not None and len(q_fake.shape) > 2:
        need_transpose = isinstance(q_fake.shape[1], int)
        ad_logger.debug(
            f"Inferred RoPE input layout: {'[b, n, s, d]' if need_transpose else '[b, s, n, d]'}"
        )
        # Additional sanity check for the third dimension
        if need_transpose:
            if not isinstance(q_fake.shape[2], torch.SymInt):
                ad_logger.warning(
                    "Sanity check failed: q_fake.shape[2] should be symbolic. Defaulting to [b, n, s, d]"
                )
                need_transpose = True
        else:
            if not isinstance(q_fake.shape[1], torch.SymInt):
                ad_logger.warning(
                    "Sanity check failed: q_fake.shape[1] should be symbolic. Defaulting to [b, n, s, d]"
                )
                need_transpose = True
    else:
        ad_logger.warning("Unable to infer layout of q node. Defaulting to [b, n, s, d].")
        need_transpose = True

    with graph.inserting_before(q_match["add_node"]):
        if need_transpose:
            q_for_op = graph.call_function(torch.ops.aten.transpose, args=(q_node, 1, 2))
            k_for_op = graph.call_function(torch.ops.aten.transpose, args=(k_node, 1, 2))
            q_for_op_contig = graph.call_method("contiguous", (q_for_op,))
            k_for_op_contig = graph.call_method("contiguous", (k_for_op,))
        else:
            q_for_op_contig, k_for_op_contig = q_node, k_node

        head_dim = cos_node.meta["val"].shape[-1]
        half_head_dim = head_dim // 2

        cache = rope_flash_cache
        cache_key = (cos_node, sin_node)
        if cache_key in cache:
            fused_cos_sin = cache[cache_key]
        else:
            cos_prefix = graph.call_function(
                torch.ops.aten.slice, args=(cos_node, -1, 0, half_head_dim)
            )
            sin_prefix = graph.call_function(
                torch.ops.aten.slice, args=(sin_node, -1, 0, half_head_dim)
            )
            fused_cos_sin = graph.call_function(
                torch.ops.aten.cat, args=((cos_prefix, sin_prefix), -1)
            )
            fused_cos_sin = graph.call_function(operator.getitem, args=(fused_cos_sin, 0))
            fused_cos_sin = graph.call_method("to", (fused_cos_sin, torch.float32))
            cache[cache_key] = fused_cos_sin

        position_ids = _get_position_ids(
            graph,
            q_for_op_contig,
            batch_dim=0,
            seq_dim=1,
            rope_position_ids_cache=rope_position_ids_cache,
        )

        flash_node = graph.call_function(
            torch.ops.rope.flashinfer,
            args=(q_for_op_contig, k_for_op_contig, position_ids, fused_cos_sin, True),
        )

    with graph.inserting_after(flash_node):
        raw_q = graph.call_function(operator.getitem, args=(flash_node, 0))
        raw_k = graph.call_function(operator.getitem, args=(flash_node, 1))

    if need_transpose:
        with graph.inserting_after(raw_q):
            new_q = graph.call_function(torch.ops.aten.transpose, args=(raw_q, 1, 2))
        with graph.inserting_after(raw_k):
            new_k = graph.call_function(torch.ops.aten.transpose, args=(raw_k, 1, 2))
    else:
        new_q, new_k = raw_q, raw_k

    new_q.meta["val"] = q_match["add_node"].meta.get("val", None)
    new_k.meta["val"] = k_match["add_node"].meta.get("val", None)

    q_match["add_node"].replace_all_uses_with(new_q)
    k_match["add_node"].replace_all_uses_with(new_k)
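
In eager terms, the fused-cache nodes inserted above compute roughly the following from the cos/sin tensors found in the graph (a sketch with illustrative shapes; the getitem(..., 0) drops the leading batch dimension):

import torch

batch, seq_len, head_dim = 2, 4, 64
cos = torch.randn(batch, seq_len, head_dim)  # as captured from the traced graph
sin = torch.randn(batch, seq_len, head_dim)

half_head_dim = head_dim // 2
fused_cos_sin = torch.cat(
    [cos[..., :half_head_dim], sin[..., :half_head_dim]], dim=-1
)[0].to(torch.float32)
# -> [seq_len, head_dim]: first-half cosines then first-half sines, the cos_sin_cache
#    layout expected by torch.ops.rope.flashinfer with is_neox=True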
def _process_rope_v2(
    graph: GraphModule,
    q_match: Dict[str, Node],
    k_match: Dict[str, Node],
    rope_flash_cache: DefaultDict[Any, Optional[Node]],
    rope_position_ids_cache: Dict[str, Node],
) -> None:
    """
    Process a region that matched the complex-multiplication RoPE pattern (v2).
    Inserts the custom op (flashinfer) after extracting frequency information,
    and replaces the original type_as nodes.
    Precomputes positional IDs and the fused cosine-sine cache as explicit nodes,
    and reuses those nodes when possible.
    """
    q_node = q_match["input"]
    k_node = k_match["input"]
    inv_freq_node = q_match["inv_freq"]

    if inv_freq_node != k_match["inv_freq"]:
        raise RuntimeError("Mismatch of freqs_cis (inv_freq) between branches.")

    # Sanity check that input layout is BSND (no transpose needed).
    q_fake = q_node.meta.get("val", None)
    if q_fake is not None and len(q_fake.shape) > 2:
        if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)):
            ad_logger.warning(
                f"""Sanity check failed: q_fake should have shape [b, s, n, d],
                s should be symbolic and n should be int, instead got shape {q_fake.shape}"""
            )
    else:
        ad_logger.warning(
            f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}"
        )

    # Retrieve or register the lookup table for inv_freq_node -> cos_sin_flash
    cache = rope_flash_cache
    if inv_freq_node in cache:
        cos_sin_flash = cache[inv_freq_node]
    else:
        # Compute the fused cosine/sine cache.
        with graph.inserting_after(inv_freq_node):
            real_part = graph.call_function(torch.ops.aten.real, args=(inv_freq_node,))
            imag_part = graph.call_function(torch.ops.aten.imag, args=(inv_freq_node,))
        with graph.inserting_after(real_part):
            cos_sin_flash_3d = graph.call_function(
                torch.ops.aten.cat, args=((real_part, imag_part), -1)
            )
        with graph.inserting_after(cos_sin_flash_3d):
            cos_sin_flash = graph.call_function(operator.getitem, args=(cos_sin_flash_3d, 0))
        with graph.inserting_after(cos_sin_flash):
            cos_sin_flash = graph.call_method("to", (cos_sin_flash, torch.float32))
        cache[inv_freq_node] = cos_sin_flash

    with graph.inserting_before(q_match["out"]):
        position_ids = _get_position_ids(
            graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=rope_position_ids_cache
        )
        flash_node = graph.call_function(
            torch.ops.rope.flashinfer,
            args=(q_node, k_node, position_ids, cos_sin_flash, False),
        )

    with graph.inserting_after(flash_node):
        raw_q = graph.call_function(operator.getitem, args=(flash_node, 0))
        raw_k = graph.call_function(operator.getitem, args=(flash_node, 1))

    raw_q.meta["val"] = q_match["out"].meta.get("val", None)
    raw_k.meta["val"] = k_match["out"].meta.get("val", None)

    q_match["out"].replace_all_uses_with(raw_q)
    k_match["out"].replace_all_uses_with(raw_k)
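
The cache construction above has a simple eager analogue (a sketch with illustrative shapes): concatenate the real and imaginary parts of freqs_cis, drop the batch dimension, and cast to float32:

import torch

batch, seq_len, head_dim = 2, 4, 64
angles = torch.randn(batch, seq_len, head_dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex tensor, as captured from the graph

cos_sin_flash = torch.cat([freqs_cis.real, freqs_cis.imag], dim=-1)[0].to(torch.float32)
# -> [seq_len, head_dim]; passed to torch.ops.rope.flashinfer with is_neox=False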
def _get_position_ids(
    graph: GraphModule,
    q_node: Node,
    batch_dim: int = 0,
    seq_dim: int = 1,
    rope_position_ids_cache: Dict[str, Node] = None,
) -> Node:
    """
    Retrieves the cached position_ids from the graph if available, or computes and caches them.
    It uses the symbolic batch and sequence sizes from q_node with the provided dimension indices.
    """
    if rope_position_ids_cache is None:
        rope_position_ids_cache = {}

    if "position_ids" in rope_position_ids_cache:
        return rope_position_ids_cache["position_ids"]

    sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, batch_dim))
    sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, seq_dim))

    # Retrieve device information, ensuring it is a torch.device.
    device = q_node.meta.get("device", "cpu")
    if isinstance(device, str):
        device = torch.device(device)

    # Build positions: arange(sym_seq) -> view -> expand -> flatten.
    positions_node = graph.call_function(
        torch.ops.aten.arange,
        args=(sym_seq,),
        kwargs={"dtype": torch.float32, "device": device, "pin_memory": False},
    )
    positions_node = graph.call_function(torch.ops.aten.view, args=(positions_node, (1, -1)))
    positions_node = graph.call_function(
        torch.ops.aten.expand, args=(positions_node, (sym_batch, -1))
    )
    position_ids = graph.call_function(torch.ops.aten.flatten, args=(positions_node,))
    rope_position_ids_cache["position_ids"] = position_ids
    return position_ids
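
The subgraph built here evaluates to the same flat position ids one would write in eager mode (a sketch; batch and seq_len stand in for the symbolic sizes taken from q_node):

import torch

batch, seq_len = 2, 4
position_ids = (
    torch.arange(seq_len, dtype=torch.float32)  # the graph version also uses a float32 arange
    .view(1, -1)
    .expand(batch, -1)
    .flatten()
)
# -> tensor([0., 1., 2., 3., 0., 1., 2., 3.]): one entry per token in [batch * seq_len] order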
@@ -1,7 +1,7 @@
"""Common utils for torch fx graph transformation."""

from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
from typing import Callable, Iterable, List, Optional, Tuple, Union

import torch
from torch._ops import OpOverload, OpOverloadPacket
@@ -287,3 +287,23 @@ def identify_regions_between_residuals(gm: GraphModule) -> List[Node]:
    boundary_nodes.append(output_node)

    return boundary_nodes


def bfs(
    node: Node, target: Callable, attr_next: str = "users", boundary: Optional[Node] = None
) -> Node:
    queue = [node]
    visited = set()
    while queue:
        cur_node = queue.pop(0)
        if boundary is not None and cur_node == boundary:
            continue  # Skip the boundary node.
        if target(cur_node):
            return cur_node
        for next_node in getattr(cur_node, attr_next):
            if boundary is not None and next_node == boundary:
                continue  # Do not expand past the boundary.
            if next_node not in visited:
                visited.add(next_node)
                queue.append(next_node)
    raise RuntimeError(f"Could not find node with target condition {target}.")
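
As a quick illustration of the bounded BFS helper, here is a self-contained sketch (the toy module and its names are hypothetical) that traces a small model and walks upstream from the relu node to the linear call:

import torch
from torch.fx import symbolic_trace

from tensorrt_llm._torch.auto_deploy.utils.node_utils import bfs


class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.lin(x))


gm = symbolic_trace(Tiny())
relu_node = next(n for n in gm.graph.nodes if n.target is torch.relu)
# Walk the inputs (attr_next="all_input_nodes") until the predicate matches.
lin_node = bfs(relu_node, lambda n: n.op == "call_module", attr_next="all_input_nodes")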
@@ -1,5 +1,5 @@
import copy
from typing import Callable, List, Optional
from typing import Callable, Dict, List, Optional

import numpy as np
import torch
@@ -34,6 +34,7 @@ def run_test(
    rtol: float = 1e-3,
    test_load_hook: bool = True,
    strict_loading: bool = True,
    dynamic_shapes: Dict = None,
    *args,  # Additional arguments for transform
) -> GraphModule:
    # run model once
@@ -44,7 +45,7 @@ def run_test(
    print(num_params_model)

    # export + check (we clone the state dict to have a bit more freedom in testing below)
    gm = torch_export_to_gm(model, args=(x,), clone=True)
    gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
    print(gm)
    y_gm = gm(x)
    num_params_gm = count_parameters(gm)
@@ -0,0 +1,145 @@
from typing import Tuple

import flashinfer
import pytest
import torch

import tensorrt_llm._torch.auto_deploy  # noqa: F401

torch.manual_seed(0)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


@pytest.mark.parametrize("head_dim", [64, 256])  # head_dim must be a multiple of 64
@pytest.mark.parametrize(
    "dtype,atol,rtol",
    [
        (torch.bfloat16, 1e-4, 1e-4),
        (torch.float16, 5e-4, 5e-4),
    ],
    ids=["bfloat16", "float16"],  # q/k must be in half precision
)
def test_flashinfer_and_custom_rope_ops(dtype, atol, rtol, head_dim):
    device = "cuda"
    batch = 2
    seq_len = 4
    n_head = 3

    # Prepare rotary embedding values.
    inv_freq = 1.0 / (
        10000
        ** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
    )
    positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
    angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0)  # [seq_len, head_dim//2]
    cos_vals = torch.cos(angles)  # [seq_len, head_dim//2]
    sin_vals = torch.sin(angles)  # [seq_len, head_dim//2]

    # For direct FlashInfer call: non-interleaved cache [seq_len, head_dim] (concatenated).
    cos_sin_cache = torch.cat([cos_vals, sin_vals], dim=1)
    # For HF and the custom op: duplicated layout [seq_len, head_dim].
    cos_new = torch.cat([cos_vals, cos_vals], dim=-1)
    sin_new = torch.cat([sin_vals, sin_vals], dim=-1)

    query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
    key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)

    # Direct FlashInfer kernel call.
    query_flat = query.view(batch * seq_len, n_head * head_dim)
    key_flat = key.view(batch * seq_len, n_head * head_dim)
    positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)])
    q_flash, k_flash = flashinfer.rope.apply_rope_with_cos_sin_cache(
        positions, query_flat, key_flat, head_dim, cos_sin_cache, is_neox=True
    )
    q_flash = q_flash.view(batch, seq_len, n_head, head_dim)
    k_flash = k_flash.view(batch, seq_len, n_head, head_dim)

    # HF implementation using apply_rotary_pos_emb.
    # HF expects [batch, n_head, seq_len, head_dim] for unsqueeze_dim=1
    q_for_hf = query.transpose(1, 2).clone()
    k_for_hf = key.transpose(1, 2).clone()
    cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1)  # [batch, seq_len, head_dim]
    sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1)  # [batch, seq_len, head_dim]
    q_hf, k_hf = apply_rotary_pos_emb(q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1)

    # Convert outputs to [batch, seq_len, n_head, head_dim]
    q_hf = q_hf.transpose(1, 2).to(dtype)
    k_hf = k_hf.transpose(1, 2).to(dtype)

    # Custom op call
    custom_q, custom_k = torch.ops.rope.flashinfer(query, key, positions, cos_sin_cache, True)

    torch.testing.assert_close(q_hf, q_flash, rtol=rtol, atol=atol)
    torch.testing.assert_close(k_hf, k_flash, rtol=rtol, atol=atol)
    torch.testing.assert_close(q_hf, custom_q, rtol=rtol, atol=atol)
    torch.testing.assert_close(k_hf, custom_k, rtol=rtol, atol=atol)


# Version 2: complex multiplication approach
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,  # Expected shape: (B, seq, head_dim//2)
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


@pytest.mark.parametrize("head_dim", [64, 256])  # Must be a multiple of 64
@pytest.mark.parametrize(
    "dtype,atol,rtol",
    [
        (torch.bfloat16, 1e-5, 1e-5),
        (torch.float16, 5e-4, 5e-4),
    ],
    ids=["bfloat16", "float16"],  # q/k must be in half precision
)
def test_flashinfer_complex_rotary(dtype, atol, rtol, head_dim):
    device = "cuda"
    batch = 2
    seq_len = 4
    n_head = 3

    inv_freq = 1.0 / (
        10000
        ** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
    )
    positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
    angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0)  # shape: (seq_len, head_dim//2)
    freqs_cis = torch.polar(torch.ones((seq_len, head_dim // 2), device=device), angles)
    freqs_cis = freqs_cis.unsqueeze(0).expand(batch, -1, -1)  # shape: (B, seq, head_dim//2)

    query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
    key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)

    out_q_v2, out_k_v2 = apply_rotary_emb(query, key, freqs_cis)

    cos_from_freqs = torch.real(freqs_cis)  # (B, seq, head_dim//2)
    sin_from_freqs = torch.imag(freqs_cis)  # (B, seq, head_dim//2)
    cos_sin_cache = torch.cat([cos_from_freqs, sin_from_freqs], dim=-1)[0]  # (seq, head_dim)

    # q/k of llama4 rope is interleaved
    positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)])
    custom_q, custom_k = torch.ops.rope.flashinfer(query, key, positions, cos_sin_cache, False)

    torch.testing.assert_close(out_q_v2, custom_q, rtol=rtol, atol=atol)
    torch.testing.assert_close(out_k_v2, custom_k, rtol=rtol, atol=atol)
@@ -0,0 +1,211 @@
import pytest
import torch
from _graph_test_helpers import run_test
from torch.export import Dim

from tensorrt_llm._torch.auto_deploy.transformations.library.rope import (
    match_rope_v1,
    match_rope_v2,
)
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op

torch.manual_seed(0)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1
):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def _precompute_freqs_cis(seq_len: int, head_dim: int, rope_theta: float):
    dtype = torch.float32
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos().to(dtype)
    sin = emb.sin().to(dtype)
    return cos, sin


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,  # Expected shape: (B, seq, head_dim//2) and complex dtype.
):
    # Reshape the inputs to pair the last dimension.
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Multiply with frequencies. Note that freqs_cis is expected to broadcast with an extra head dim.
    xq_out = torch.view_as_real(xq_complex * freqs_cis[:, :, None, :]).flatten(3)
    xk_out = torch.view_as_real(xk_complex * freqs_cis[:, :, None, :]).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


def _precompute_freqs_cis_v2(seq_len: int, head_dim: int, rope_theta: float):
    """
    Compute the frequency tensor for the complex multiplication RoPE variant.
    Returns a complex tensor of shape (seq_len, head_dim//2).
    """
    inv_freq = 1.0 / (
        rope_theta ** (torch.arange(0, head_dim // 2, dtype=torch.float32) / (head_dim // 2))
    )
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)  # (seq_len, head_dim//2)
    # Create a complex tensor from magnitude=1 and the computed angles.
    freqs_cis = torch.polar(torch.ones_like(angles), angles)
    return freqs_cis


class RotaryModel(torch.nn.Module):
    def __init__(
        self,
        hidden_size: int,
        max_seq_len: int,
        num_heads: int,
        num_kv_heads: int,
        layout: str = "BNSD",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_seq_len = max_seq_len
        self.layout = layout
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden_size // num_heads
        self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim)
        self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.linear_q(x)
        k = self.linear_k(x)

        batch, seq, _ = q.shape
        q = q.view(batch, seq, self.num_heads, self.head_dim)
        k = k.view(batch, seq, self.num_kv_heads, self.head_dim)

        if self.layout == "BNSD":
            q = q.permute(0, 2, 1, 3).contiguous()  # [B, N, S, D]
            k = k.permute(0, 2, 1, 3).contiguous()
            unsqueeze_dim = 1
        else:  # BSND
            unsqueeze_dim = 2

        cos, sin = _precompute_freqs_cis(seq, self.head_dim, rope_theta=10000)
        cos = cos.to(q.device).unsqueeze(0).expand(batch, -1, -1)
        sin = sin.to(q.device).unsqueeze(0).expand(batch, -1, -1)

        q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim)
        if self.layout == "BNSD":
            # [B, N, S, D] -> [B, S, N*D]
            q_embed = q_embed.permute(0, 2, 1, 3).reshape(batch, seq, -1)
            k_embed = k_embed.permute(0, 2, 1, 3).reshape(batch, seq, -1)
        else:  # BSND
            q_embed = q_embed.reshape(batch, seq, -1)
            k_embed = k_embed.reshape(batch, seq, -1)

        output = torch.cat([q_embed, k_embed], dim=-1)
        return output.to(torch.float16)

    def get_dynamic_shapes(self):
        return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}


class RotaryModelV2(torch.nn.Module):
    def __init__(self, hidden_size: int, max_seq_len: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = hidden_size // num_heads

        self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim)
        self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq, _ = x.shape

        q = self.linear_q(x).view(batch, seq, self.num_heads, self.head_dim)
        k = self.linear_k(x).view(batch, seq, self.num_kv_heads, self.head_dim)

        freqs_cis = _precompute_freqs_cis_v2(seq, self.head_dim, rope_theta=10000)
        freqs_cis = freqs_cis.to(x.device).unsqueeze(0).expand(batch, -1, -1)

        q_embed, k_embed = apply_rotary_emb(q, k, freqs_cis)

        q_embed = q_embed.reshape(batch, seq, -1)
        k_embed = k_embed.reshape(batch, seq, -1)

        output = torch.cat([q_embed, k_embed], dim=-1)
        return output.to(torch.float16)

    def get_dynamic_shapes(self):
        return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)}


@pytest.mark.parametrize("layout", ["BNSD", "BSND"])
@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4)])
@torch.inference_mode()
def test_match_rope(layout, num_heads, num_kv_heads):
    batch_size, seq_len = 8, 16
    hidden_size = 512
    max_position_embeddings = seq_len

    model = RotaryModel(
        hidden_size, max_position_embeddings, num_heads, num_kv_heads, layout=layout
    ).to("cuda", dtype=torch.float16)
    x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
    dynamic_shapes = model.get_dynamic_shapes()

    _ = run_test(
        model,
        x,
        match_rope_v1,
        lambda gm: any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes),
        lambda num_p_og: num_p_og,
        atol=1e-3,
        rtol=1e-3,
        test_load_hook=True,
        strict_loading=True,
        dynamic_shapes=dynamic_shapes,
    )


@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4)])
@torch.inference_mode()
def test_match_rope_v2(num_heads, num_kv_heads):
    batch_size, seq_len = 8, 16
    hidden_size = 512
    max_position_embeddings = seq_len

    model = RotaryModelV2(hidden_size, max_position_embeddings, num_heads, num_kv_heads).to(
        "cuda", dtype=torch.float16
    )

    x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16)
    dynamic_shapes = model.get_dynamic_shapes()

    _ = run_test(
        model,
        x,
        match_rope_v2,
        lambda gm: any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes),
        lambda num_p_og: num_p_og,
        atol=1e-3,
        rtol=1e-3,
        test_load_hook=True,
        strict_loading=True,
        dynamic_shapes=dynamic_shapes,
    )