diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py index d5a5374ad3..3236e0267c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py @@ -10,4 +10,5 @@ from .mla import * from .quant import * from .rope import * from .torch_attention import * +from .torch_rope import * from .triton_attention import * diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py index 8d12d00d00..4746e6fb12 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_rope.py @@ -21,10 +21,11 @@ def apply_rope_with_input_pos_flashinfer( Tensors of shape [batch, seq_len, n_head, head_dim] (or a 3D variant) in half precision. Note: head_dim must be a multiple of 64. - position_ids (torch.Tensor): - Precomputed tensor of positional indices; it is shared across calls in the graph. + Precomputed tensor of positional indices indicating idx in cos_sin_cache for each token; + Shape [batch, seq_len] or [batch * seq_len] - cos_sin_cache (torch.Tensor): Precomputed fused tensor created by concatenating the first half of the cosine and sine - components derived from the inv_freq. + components derived from the inv_freq. Shape [max_seq_len, head_dim]. Must be float32. - is_neox (bool): Flag to indicate whether to invoke the FlashInfer kernel in Neox mode. @@ -35,14 +36,13 @@ def apply_rope_with_input_pos_flashinfer( """ q_shape = q.shape k_shape = k.shape - batch_size, seq_len = q_shape[:2] - head_dim = cos_sin_cache.shape[-1] - q_flat = q.view(batch_size * seq_len, -1) - k_flat = k.view(batch_size * seq_len, -1) + position_ids = position_ids.view(-1).to(q.device) + num_nnz = position_ids.shape[0] - position_ids = position_ids.to(q.device) + q_flat = q.view(num_nnz, -1) + k_flat = k.view(num_nnz, -1) query_rotated_flash, key_rotated_flash = flashinfer.rope.apply_rope_with_cos_sin_cache( position_ids, q_flat, k_flat, head_dim, cos_sin_cache, is_neox=is_neox @@ -60,4 +60,4 @@ def apply_rope_with_input_pos_flashinfer_fake( cos_sin_cache: torch.Tensor, is_neox: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: - return q, k + return torch.empty_like(q), torch.empty_like(k) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py index 1a2209bb3a..c077d66b58 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py @@ -17,7 +17,6 @@ from .attention_interface import ( PrepareMetadataCallable, SequenceInfo, ) -from .torch_attention import apply_rotary_pos_emb_ds from .triton_attention import _flattened_context_mha, _generate_mha Constant = Union[int, float, str, None] @@ -74,26 +73,38 @@ def fused_flattened_mla_with_cache( # Apply RoPE if cos_sin_stacked.numel() > 0: # Extract cos and sin from freqs_cis - cos = cos_sin_stacked[0, ...] - sin = cos_sin_stacked[1, ...] + cos_base = cos_sin_stacked[0, ...] + sin_base = cos_sin_stacked[1, ...] 
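For illustration (not part of this patch): the revised docstring for `apply_rope_with_input_pos_flashinfer` above pins down its inputs, a flattened `position_ids` of shape [batch * seq_len] and a float32 `cos_sin_cache` of shape [max_seq_len, head_dim] built from the cos and sin halves. Below is a minimal sketch of constructing those inputs and calling the op, modeled on the removed unit test; it assumes flashinfer is installed, a CUDA device is available, and that importing `tensorrt_llm._torch.auto_deploy` registers `torch.ops.rope.flashinfer`.

import torch
import tensorrt_llm._torch.auto_deploy  # noqa: F401  # registers torch.ops.rope.flashinfer

B, S, N, D = 2, 16, 8, 64  # head_dim must be a multiple of 64
q = torch.randn(B, S, N, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, S, N, D, dtype=torch.float16, device="cuda")

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D // 2, device="cuda").float() / (D // 2)))
angles = torch.arange(S, device="cuda").float()[:, None] * inv_freq[None, :]  # [S, D//2]
cos_sin_cache = torch.cat([angles.cos(), angles.sin()], dim=-1)  # [S, D], float32
position_ids = torch.arange(S, device="cuda").repeat(B)  # flattened [B * S]

q_rot, k_rot = torch.ops.rope.flashinfer(q, k, position_ids, cos_sin_cache, True)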
# TODO: Use triton kernels for RoPE # TODO: Add yarn support - for idx in range(seq_len.shape[0]): - ( - q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...], - k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...], - ) = apply_rotary_pos_emb_ds( - q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...], - k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...], + for i in range(seq_len.shape[0]): + start = seq_start[i] + length = seq_len[i] + + # build position_ids + if s == 1: + idx = (input_pos[i] + length - 1).item() + pos_ids = torch.tensor(idx, device=cos_base.device) + else: + pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device) + + cos = cos_base[pos_ids] # [..., 1, head_dim] + sin = sin_base[pos_ids] + q_slice = q_pe[start : start + length] + k_slice = k_pe[start : start + length] + + q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving( + q_slice, + k_slice, cos, sin, - torch.arange(input_pos[idx] + seq_len[idx])[-1] - if s == 1 - else torch.arange(input_pos[idx] + seq_len[idx]), -2, ) + q_pe[start : start + length] = q_rot + k_pe[start : start + length] = k_rot + # Create query_states, key_states query_states = torch.cat((q_nope, q_pe), dim=-1) # [b*s,n,d] key_states = torch.cat((k_nope, k_pe.expand(*bs_view, num_heads, -1)), dim=-1) # [b*s,n,d] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py index de3d309d1e..1c9b9f14c6 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py @@ -138,92 +138,6 @@ def bsnd_grouped_sdpa_fake( return torch.empty_like(query.contiguous()) -# Function to apply rotary positional embeddings (RoPE) -def apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - seq_len: int, - head_dim: int, - rope_theta: Optional[float] = None, - rope_scale: Optional[float] = None, -): - """ - Apply rotary positional embeddings to query and key tensors. 
- Args: - q: Query tensor of shape [batch, n_heads, seq_len, head_dim] - k: Key tensor of shape [batch, n_kv_heads, seq_len, head_dim] - seq_len: Sequence length - head_dim: Dimension of each head - rope_theta: Base value for RoPE (default 10000.0) - rope_scale: Scaling factor for positions (default 1.0) - Returns: - Tuple of transformed query and key tensors - """ - device = q.device - original_dtype = q.dtype - - # Apply default values if None - theta = 10000.0 if rope_theta is None else rope_theta - scale = 1.0 if rope_scale is None else rope_scale - - # Generate position indices - position = torch.arange(seq_len, device=device).float() - # Apply scaling factor to positions if provided - if scale != 1.0: - position = position / scale - - # Create the frequency matrix - ensure stable computation in float32 - inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)) - # Compute the product of positions and frequencies - # Shape: [seq_len, head_dim/2] - freqs = torch.outer(position, inv_freq) - - # Compute the rotation matrix elements: cos and sin - # Shape: [seq_len, head_dim/2] - emb = torch.cat((freqs, freqs), dim=-1) - # Ensure stable computation of sin/cos in float32 - cos = torch.cos(emb).to(dtype=torch.float32) - sin = torch.sin(emb).to(dtype=torch.float32) - - # Reshape for broadcasting - # Shape: [1, 1, seq_len, head_dim] - cos = cos.view(1, 1, seq_len, head_dim) - sin = sin.view(1, 1, seq_len, head_dim) - - # Always compute in float32 for numerical stability - q_float = q.to(dtype=torch.float32) - k_float = k.to(dtype=torch.float32) - - # For the even indices of the dimension - q_embed_even = q_float[..., 0::2] - q_embed_odd = q_float[..., 1::2] - k_embed_even = k_float[..., 0::2] - k_embed_odd = k_float[..., 1::2] - - # Apply the rotation using the identities: - # q' = q * cos + rotate(q) * sin - # k' = k * cos + rotate(k) * sin - # where rotate(x) swaps the even and odd dimensions and negates the odd dimensions - q_rotated = torch.cat( - [ - q_embed_even * cos[..., 0::2] - q_embed_odd * sin[..., 0::2], - q_embed_odd * cos[..., 1::2] + q_embed_even * sin[..., 1::2], - ], - dim=-1, - ) - - k_rotated = torch.cat( - [ - k_embed_even * cos[..., 0::2] - k_embed_odd * sin[..., 0::2], - k_embed_odd * cos[..., 1::2] + k_embed_even * sin[..., 1::2], - ], - dim=-1, - ) - - # Convert back to the original dtype - return q_rotated.to(dtype=original_dtype), k_rotated.to(dtype=original_dtype) - - def update_kv_cache( key_states: torch.Tensor, value_states: torch.Tensor, @@ -248,46 +162,6 @@ def update_kv_cache( ) -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -@torch.inference_mode() -def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. 
- unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - - q = q.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(q) - k = k.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(k) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - @torch.library.custom_op("attention::fused_mla_ref", mutates_args=()) def fused_mla_ref( q_nope: torch.Tensor, @@ -325,23 +199,33 @@ def fused_mla_ref( value_states = value_states.transpose(1, 2).view(*bs_view, -1, v_head_dim).contiguous() if freqs_cis is not None: - cos = freqs_cis[0, ...] - sin = freqs_cis[1, ...] - for idx in range(seq_len.shape[0]): - ( - q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...], - k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...], - ) = apply_rotary_pos_emb_ds( - q_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...], - k_pe[seq_start[idx] : seq_start[idx] + seq_len[idx], ...], + cos_base = freqs_cis[0, ...] + sin_base = freqs_cis[1, ...] 
+ for i in range(seq_len.shape[0]): + start = seq_start[i] + length = seq_len[i] + if q_len == 1: + idx = (input_pos[i] + length - 1).item() + pos_ids = torch.tensor(idx, device=cos_base.device) + else: + pos_ids = torch.arange(input_pos[i], input_pos[i] + length, device=cos_base.device) + + cos = cos_base[pos_ids] # [..., 1, head_dim] + sin = sin_base[pos_ids] + q_slice = q_pe[start : start + length] + k_slice = k_pe[start : start + length] + + q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_qk_interleaving( + q_slice, + k_slice, cos, sin, - torch.arange(input_pos[idx] + seq_len[idx])[-1] - if q_len == 1 - else torch.arange(input_pos[idx] + seq_len[idx]), -2, ) + q_pe[start : start + length] = q_rot + k_pe[start : start + length] = k_rot + query_states = k_pe.new_empty(*bs_view, num_heads, q_head_dim) # [b*s,n,d] query_states[..., :qk_nope_head_dim] = q_nope query_states[..., qk_nope_head_dim:] = q_pe @@ -454,7 +338,9 @@ def fused_mla( k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1) kv_seq_len = value_states.shape[-2] - q_pe, k_pe = apply_rotary_pos_emb_ds(q_pe, k_pe, cos, sin, position_ids) + cos = cos[position_ids] + sin = sin[position_ids] + q_pe, k_pe = torch.ops.rope.torch_apply_rope_with_qk_interleaving(q_pe, k_pe, cos, sin) query_states = k_pe.new_empty(bs, num_heads, q_len, q_head_dim) query_states[:, :, :, :qk_nope_head_dim] = q_nope diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_rope.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_rope.py new file mode 100644 index 0000000000..33286d5946 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_rope.py @@ -0,0 +1,100 @@ +from typing import Tuple + +import torch + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@torch.library.custom_op("rope::torch_apply_rope_with_explicit_cos_sin", mutates_args=()) +def torch_apply_rope_with_explicit_cos_sin( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Reference PyTorch implementation of HF-style RoPE: + - Input layout: non-interleaved, [B, N, S, D] with unsqueeze_dim=1 and + [B, S, N, D] with unsqueeze_dim=2, default is [B, N, S, D] + - Frequencies are provided as separate `cos` and `sin` tensors of shape [B, S, head_dim]. 
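For illustration (not part of this patch): a minimal sketch of driving this reference op with the layouts documented above, i.e. [B, N, S, D] query/key with unsqueeze_dim=1 and [B, S, head_dim] cos/sin tables. Shapes are arbitrary, and the import assumes the rope::* ops are registered when the auto_deploy package is imported, as in the unit tests.

import torch
import tensorrt_llm._torch.auto_deploy  # noqa: F401

B, N, S, D = 2, 8, 16, 64
q = torch.randn(B, N, S, D, dtype=torch.float16)
k = torch.randn(B, N, S, D, dtype=torch.float16)

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))
angles = torch.arange(S).float()[:, None] * inv_freq[None, :]  # [S, D//2]
cos = torch.cat((angles.cos(), angles.cos()), dim=-1).expand(B, S, D)  # [B, S, D]
sin = torch.cat((angles.sin(), angles.sin()), dim=-1).expand(B, S, D)

# unsqueeze_dim=1 matches the [B, N, S, D] layout.
q_rot, k_rot = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin(q, k, cos, sin, 1)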
+ """ + # in HF, cos/sin tensor are passed in as x.dtype, this is to double ensure + cos = cos.type_as(q) + sin = sin.type_as(q) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +@torch_apply_rope_with_explicit_cos_sin.register_fake +def torch_apply_rope_with_explicit_cos_sin_fake( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(q), torch.empty_like(k) + + +@torch.library.custom_op("rope::torch_apply_rope_with_complex_freqs", mutates_args=()) +def torch_apply_rope_with_complex_freqs( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, # shape [B, S, head_dim//2] + unsqueeze_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Reference PyTorch implementation of interleaved (complex) RoPE: + - Input layout: [B, S, N, D] (interleaved) + - Frequencies are combined into a single complex-valued tensor `freqs_cis` + of shape [B, S, head_dim // 2]. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs = freqs_cis.unsqueeze(unsqueeze_dim) + xq_out = torch.view_as_real(xq_ * freqs).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +@torch_apply_rope_with_complex_freqs.register_fake +def torch_apply_rope_with_complex_freqs_fake( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, # shape [B, S, head_dim//2] + unsqueeze_dim: int = 2, +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(xq), torch.empty_like(xk) + + +@torch.library.custom_op("rope::torch_apply_rope_with_qk_interleaving", mutates_args=()) +def torch_apply_rope_with_qk_interleaving( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + DeepSeek-style RoPE: interleaves Q/K channels and returns rotated (q_embed, k_embed). + - Input layout: [B, S, N, D] or [B*S, N, D] or [B, N, S, D] + - Frequencies are provided as separate `cos` and `sin` tensors of shape + [B, S, 1, D] or [B*S, 1, D] or [B, 1, S, D] matching input shape. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + # Rewrite below code to accept 3D input: + # b, h, s, d = q.shape + # q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + # b, h, s, d = k.shape + # k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + q = q.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(q) + k = k.unflatten(-1, (-1, 2)).transpose(-1, -2).reshape_as(k) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +@torch_apply_rope_with_qk_interleaving.register_fake +def torch_apply_rope_with_qk_interleaving_fake( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(q), torch.empty_like(k) diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py index b3f59fafc8..b717a32359 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/fused_moe.py @@ -1,12 +1,12 @@ from collections import defaultdict -from typing import Callable, Optional +from typing import Optional import torch from torch.fx import GraphModule, Node from ...utils.cuda_mem_tracker import cuda_memory_tracker from ...utils.logger import ad_logger -from ...utils.node_utils import identify_regions_between_residuals, is_linear_op, is_op +from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op from .._graph import canonicalize_graph @@ -36,7 +36,7 @@ def match_moe_pattern(gm: GraphModule) -> GraphModule: common_ancessor2 = _find_lowest_common_ancessor(arg2_list) if not common_ancessor2: continue - selected_experts = _bfs( + selected_experts = bfs( common_ancessor2, lambda node: is_op(node, torch.ops.aten.one_hot), attr_next="all_input_nodes", @@ -153,26 +153,6 @@ def _insert_fused_moe_ops(gm: GraphModule): graph.erase_node(node) -def _bfs( - node: Node, target: Callable, attr_next: str = "users", boundary: Optional[Node] = None -) -> Node: - queue = [node] - visited = set() - while queue: - cur_node = queue.pop(0) - if boundary is not None and cur_node == boundary: - continue # Skip the boundary node. - if target(cur_node): - return cur_node - for next_node in getattr(cur_node, attr_next): - if boundary is not None and next_node == boundary: - continue # Do not expand past the boundary. - if next_node not in visited: - visited.add(next_node) - queue.append(next_node) - raise RuntimeError(f"Could not find node with target condition {target}.") - - def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]: """ Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following @@ -326,7 +306,7 @@ def _find_final_hidden_state_node( For each expert output node (from the expert compute pattern), this function: 1. Retrieves a multiplication node from its users. 2. Extracts the second argument from the multiplication node (assumed to be the index node). - 3. Uses a BFS (via _bfs) to locate the subsequent index_add_ node (guarded by the end_boundary). + 3. Uses a BFS to locate the subsequent index_add_ node (guarded by the end_boundary). After collecting all such index_add_ nodes, the final hidden state node is determined as the one that is not used by any of the other index_add_ nodes. 
@@ -346,7 +326,7 @@ def _find_final_hidden_state_node( if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2): return None index_node = mul_node.args[1] - index_add_node = _bfs( + index_add_node = bfs( index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary ) if not index_add_node: @@ -412,7 +392,7 @@ def _remove_dead_inplace_nodes_in_region( return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0 try: - node_to_remove = _bfs(start_boundary, target, attr_next="users", boundary=end_boundary) + node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary) ad_logger.debug(f"Removing In-place Dead Node: {node_to_remove}") graph.erase_node(node_to_remove) return True diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py index da3970916b..48b3678b7f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py @@ -1,86 +1,141 @@ +""" +This transformation defines two main RoPE (Rotary Positional Embedding) pattern matchers used +to identify and replace RoPE subgraphs with a custom op (`torch.ops.rope.flashinfer`). + +Supported RoPE variants: + +1. Explicit Cos/Sin Multiplication (HF-style, e.g., LLaMA, Mixtral, Qwen) + - Input layout: non-interleaved, [B, N, S, D] with unsqueeze_dim=1 and + [B, S, N, D] with unsqueeze_dim=2, default is [B, N, S, D] + - Frequencies are provided as separate `cos` and `sin` tensors of shape [B, S, head_dim]. + - Source code: + def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +2. Complex Multiplication (GPTJ/Llama-stack-style, interleaved) + - Input layout: [B, S, N, D] (interleaved) + - Frequencies are combined into a single complex-valued tensor `freqs_cis` of shape [B, S, head_dim // 2]. + - Source code: + def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2) + ) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + +Supported Minor variants: +- DeepSeekV3: reshape + transpose before applying RoPE. + dynamic position-based updates to frequency cache. + +TODO: Support other variants: +- Phi-4: rotary applied only to part of the hidden dimension (q_rot, q_pass split). +- LLaMA4 Vision: 2D rotary frequencies constructed from image patches. 
+""" + import operator from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Optional +from typing import Any, DefaultDict, Dict, List, Optional, Sequence import torch from torch.fx import GraphModule, Node from ...utils.logger import ad_logger -from ...utils.node_utils import bfs, identify_regions_between_residuals, is_op +from ...utils.node_utils import bfs, extract_output_tuple, identify_regions_between_residuals, is_op from .._graph import canonicalize_graph -def match_rope_v1(gm: GraphModule) -> GraphModule: +def match_explicit_rope(gm: GraphModule) -> GraphModule: """ - Identify and replace legacy RoPE subgraphs (explicit cos/sin multiplication pattern): + Identify and replace RoPE subgraphs (explicit cos/sin multiplication pattern): - output = (raw * unsqueeze(cos)) + (rotate_half(raw) * unsqueeze(sin)) + - HF-style: output = (raw * unsqueeze(cos)) + (rotate_half(raw) * unsqueeze(sin)) + - DS-style: requires interleaving Q/K before the cos/sin mul If exactly two such branches (query and key) are detected within each region, they're replaced - by a call to `torch.ops.rope.flashinfer`. + by a call to `rope::torch_apply_rope_with_qk_interleaving` or + `rope::torch_apply_rope_with_explicit_cos_sin` respectively. """ + ad_logger.info("Match explicit(HF) style RoPE") graph = gm.graph boundary_nodes: List[torch.fx.Node] = identify_regions_between_residuals(gm) - rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None) - rope_position_ids_cache: Dict[str, Node] = {} - for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]): - matches = [] + matches = [] # list of (match_info, is_ds) node = start_boundary while node != end_boundary: if is_op(node, torch.ops.aten.add): - match_info = _match_rotary_subpattern_V1(node) - if match_info: - matches.append(match_info) + explicit = _match_explicit_rope_subpattern(node) + if explicit is not None: + raw = explicit["raw_input"] + # check if this raw is result of DS interleave + inter = _match_input_interleave_pattern(raw) + if inter is not None: + ds_match = { + "raw_input": inter["interleaved"], + "unsqueeze_cos": explicit["unsqueeze_cos"], + "unsqueeze_sin": explicit["unsqueeze_sin"], + "add_node": explicit["add_node"], + } + matches.append((ds_match, True)) + else: + matches.append((explicit, False)) node = node.next if not matches: continue if len(matches) != 2: - raise RuntimeError( + ad_logger.warning( f"Expected exactly 2 legacy RoPE branches between {start_boundary} and {end_boundary}, " f"found {len(matches)}." ) + continue - # Assume the first matched branch is query (q), second is key (k). - # This assumption is based on the default ordering in the exported graph, - # since node naming conventions don't reliably indicate q/k branches. 
- q_match, k_match = matches - _process_rope_v1( - graph, - q_match, - k_match, - start_boundary, - rope_flash_cache, - rope_position_ids_cache, - ) + (q_match, q_is_ds), (k_match, k_is_ds) = matches + if q_is_ds != k_is_ds: + ad_logger.warning("Mismatched RoPE types between q and k branches") + continue + + if q_is_ds: + _process_input_interleave_rope(graph, q_match, k_match) + else: + _process_explicit_rope(graph, q_match, k_match, start_boundary) gm = canonicalize_graph(gm) return gm -def match_rope_v2(gm: GraphModule) -> GraphModule: +def match_complex_rope(gm: GraphModule) -> GraphModule: """ Identify and replace RoPE subgraphs using complex multiplication pattern: output = type_as(flatten(view_as_real(mul(view_as_complex(reshape(to_dtype(x))), unsqueeze(freqs_cis, 2)))), x) If exactly two such branches (query and key) are detected within each region, they're replaced - by a call to `torch.ops.rope.flashinfer`. + by a call to `torch.ops.rope.torch_apply_rope_with_complex_freqs`. """ + ad_logger.info("Match Complex style RoPE") graph = gm.graph boundary_nodes: List[torch.fx.Node] = identify_regions_between_residuals(gm) - rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None) - rope_position_ids_cache: Dict[str, Node] = {} - for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]): matches = [] node = start_boundary while node != end_boundary: if is_op(node, torch.ops.aten.type_as): - match_info = _match_rotary_subpattern_V2(node) + match_info = _match_complex_rope_subpattern(node) if match_info: matches.append(match_info) node = node.next @@ -88,28 +143,344 @@ def match_rope_v2(gm: GraphModule) -> GraphModule: if not matches: continue if len(matches) != 2: - raise RuntimeError( + ad_logger.warning( f"Expected exactly 2 complex RoPE branches between {start_boundary} and {end_boundary}, " f"found {len(matches)}." ) + continue # Assume the first matched branch is query (q), second is key (k). # This assumption is based on the default ordering in the exported graph, # since node naming conventions don't reliably indicate q/k branches. q_match, k_match = matches - _process_rope_v2( - graph, - q_match, - k_match, - rope_flash_cache, - rope_position_ids_cache, - ) + _process_complex_rope(graph, q_match, k_match) gm = canonicalize_graph(gm) return gm -def _match_rotary_subpattern_V1(add_node: Node) -> Optional[Dict[str, Node]]: +def _get_default_unsqueeze_dim(op): + schema = next(iter(op._schemas.values())) + for a in schema.arguments: + if a.name == "unsqueeze_dim" and a.has_default_value: + return a.default_value + raise RuntimeError(f"No default unsqueeze_dim on {op}") + + +def match_rope_layout(gm: GraphModule, expected_layout: str = "bsnd") -> GraphModule: + """ + Match and transform input and output of rope ops to the layout specified to meet requirements of optimized ops. + Supported layout is 'bsnd' (batch, seq, head, dim). + """ + supported = {"bsnd", "bnsd"} + if expected_layout.lower() not in supported: + ad_logger.warning( + f"Unsupported RoPE layout '{expected_layout}'; expected '{supported}'. Skipping RoPE layout matching." 
+ ) + return gm + + ad_logger.info(f"Match RoPE layout to {expected_layout}") + + graph = gm.graph + rope_ops = { + torch.ops.rope.torch_apply_rope_with_explicit_cos_sin, + torch.ops.rope.torch_apply_rope_with_qk_interleaving, + torch.ops.rope.torch_apply_rope_with_complex_freqs, + } + + need_transpose = False + need_canonicalize_graph = False + for node in graph.nodes: + if not is_op(node, rope_ops): + continue + + rope_op = next(op for op in rope_ops if is_op(node, op)) + if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs): + q_node, k_node, freqs_node, *rest = node.args + unsq = rest[0] if rest else _get_default_unsqueeze_dim(rope_op) + else: + q_node, k_node, cos_node, sin_node, *rest = node.args + unsq = rest[0] if rest else _get_default_unsqueeze_dim(rope_op) + + if unsq == 2: + current_layout = "bsnd" + elif unsq == 1: + current_layout = "bnsd" + else: + ad_logger.warning( + "Unsqueeze_dim is not one of [1, 2]. " + "Unable to infer layout of q node. Skip layout matching" + ) + continue + + need_transpose = expected_layout.lower() != current_layout + + if not need_transpose: + continue + + need_canonicalize_graph = True + # retrieve q and k output node from node + q_rope_old, k_rope_old = extract_output_tuple(node, 2) + if q_rope_old is None or k_rope_old is None: + ad_logger.warning( + f"Failed to extract all two outputs from the explicit op, \ + get {q_rope_old}, {k_rope_old}, fail to match rope layout with {node} with" + ) + continue + + ad_logger.debug( + f"Inferred RoPE input layout: '{current_layout}']Mapping layout to '{expected_layout}']" + ) + with graph.inserting_before(node): + q_for_op = graph.call_function(torch.ops.aten.transpose, args=(q_node, 1, 2)) + k_for_op = graph.call_function(torch.ops.aten.transpose, args=(k_node, 1, 2)) + q_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(q_for_op,)) + k_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(k_for_op,)) + + q_for_op_contig.meta["val"] = q_node.meta["val"].transpose(1, 2) + k_for_op_contig.meta["val"] = k_node.meta["val"].transpose(1, 2) + + if is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs): + new_args = ( + q_for_op_contig, + k_for_op_contig, + freqs_node, + 2 if expected_layout.lower() == "bsnd" else 1, + ) # unsqueeze_dim updated + else: + new_args = ( + q_for_op_contig, + k_for_op_contig, + cos_node, + sin_node, + 2 if expected_layout.lower() == "bsnd" else 1, + ) # unsqueeze_dim updated + node.args = new_args + + with graph.inserting_after(q_rope_old): + q_rope_new = graph.call_function(torch.ops.aten.transpose, args=(q_rope_old, 1, 2)) + with graph.inserting_after(k_rope_old): + k_rope_new = graph.call_function(torch.ops.aten.transpose, args=(k_rope_old, 1, 2)) + + # Preserve fake tensor in meta["val"] for the transposed inputs + q_rope_new.meta["val"] = q_rope_old.meta["val"] + q_rope_old.meta["val"] = q_rope_old.meta["val"].transpose(1, 2) + k_rope_new.meta["val"] = k_rope_old.meta["val"] + k_rope_old.meta["val"] = k_rope_old.meta["val"].transpose(1, 2) + + q_rope_old.replace_all_uses_with(q_rope_new) + k_rope_old.replace_all_uses_with(k_rope_new) + q_rope_new.args = (q_rope_old, 1, 2) + k_rope_new.args = (k_rope_old, 1, 2) + + if need_canonicalize_graph: + gm = canonicalize_graph(gm) + return gm + + +def optimize_rope(gm: GraphModule) -> GraphModule: + """ + Scan the FX graph and replace calls to the torch-reference RoPE ops with + the optimized `rope::flashinfer` kernel. 
+ Precomputes positional IDs and the fused cosine-sine cache as explicit nodes, + and reuses those nodes when possible. + """ + ad_logger.info("RoPE optimization") + graph = gm.graph + rope_flash_cache: DefaultDict[Any, Optional[Node]] = defaultdict(lambda: None) + rope_position_ids_cache: Dict[str, Node] = {} + + for node in list(graph.nodes): + if is_op(node, torch.ops.rope.torch_apply_rope_with_explicit_cos_sin): + _optimize_explicit(graph, node, rope_flash_cache, rope_position_ids_cache) + elif is_op(node, torch.ops.rope.torch_apply_rope_with_complex_freqs): + _optimize_complex(graph, node, rope_flash_cache, rope_position_ids_cache) + + gm = canonicalize_graph(gm) + return gm + + +def _optimize_explicit( + graph: GraphModule, node: Node, cache: Dict[Any, Node], pos_cache: Dict[str, Node] +) -> None: + # node.args may be (q, k, cos, sin) or (q, k, cos, sin, unsq) + q_node, k_node, cos_node, sin_node, *rest = node.args + # retrieve q and k output node from node + q_rope_old, k_rope_old = extract_output_tuple(node, 2) + if q_rope_old is None or k_rope_old is None: + ad_logger.warning( + f"Failed to extract all two outputs from the explicit op, \ + get {q_rope_old}, {k_rope_old}, fail to replace {node} with flashinfer rope" + ) + return + + # Sanity check on head_dim + if not _validate_rope_inputs(q_node, k_node): + return + + # Sanity check that input layout is BSND (no transpose needed). + q_fake = q_node.meta.get("val", None) + if q_fake is not None and len(q_fake.shape) > 2: + if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)): + ad_logger.warning( + f"""Sanity check failed: q_fake should have shape [b, s, n, d], + s should be symbolic and n should be int, instead got shape {q_fake.shape}""" + ) + return + elif q_fake is not None: + ad_logger.warning( + f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}" + ) + return + + head_dim = cos_node.meta["val"].shape[-1] + half_head_dim = head_dim // 2 + + cache_key = (cos_node, sin_node) + if cache_key in cache: + fused_cos_sin_to = cache[cache_key] + else: + with graph.inserting_after(cos_node): + cos_prefix = graph.call_function( + torch.ops.aten.slice, args=(cos_node, -1, 0, half_head_dim) + ) + with graph.inserting_after(sin_node): + sin_prefix = graph.call_function( + torch.ops.aten.slice, args=(sin_node, -1, 0, half_head_dim) + ) + with graph.inserting_after(sin_prefix): + fused_cos_sin = graph.call_function( + torch.ops.aten.cat, args=((cos_prefix, sin_prefix), -1) + ) + with graph.inserting_after(q_node): + sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 0)) + sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 1)) + with graph.inserting_after(_get_last_node([sym_batch, sym_seq])): + bs_seq = graph.call_function(operator.mul, args=(sym_batch, sym_seq)) + with graph.inserting_after(_get_last_node([bs_seq, fused_cos_sin])): + fused_cos_sin_flat = graph.call_function( + torch.ops.aten.view, args=(fused_cos_sin, (bs_seq, -1)) + ) + with graph.inserting_after(fused_cos_sin_flat): + fused_cos_sin_to = graph.call_function( + torch.ops.aten.to, args=(fused_cos_sin_flat, torch.float32) + ) + cache[cache_key] = fused_cos_sin_to + + with graph.inserting_before(node): + position_ids = _get_position_ids( + graph, + q_node, + batch_dim=0, + seq_dim=1, + rope_position_ids_cache=pos_cache, + ) + flash_node = graph.call_function( + torch.ops.rope.flashinfer, + args=(q_node, k_node, position_ids, fused_cos_sin_to, True), + ) + + with 
graph.inserting_after(flash_node): + q_rope_new = graph.call_function(operator.getitem, args=(flash_node, 0)) + k_rope_new = graph.call_function(operator.getitem, args=(flash_node, 1)) + + q_rope_new.meta["val"] = q_rope_old.meta.get("val", None) + k_rope_new.meta["val"] = k_rope_old.meta.get("val", None) + + q_rope_old.replace_all_uses_with(q_rope_new) + k_rope_old.replace_all_uses_with(k_rope_new) + + graph.erase_node(q_rope_old) + graph.erase_node(k_rope_old) + + +def _optimize_complex( + graph: GraphModule, node: Node, cache: Dict[Any, Node], pos_cache: Dict[str, Node] +) -> None: + q_node, k_node, inv_freq_node = node.args + + # Sanity check on head_dim + if not _validate_rope_inputs(q_node, k_node): + return + + # Sanity check that input layout is BSND (no transpose needed). + q_fake = q_node.meta.get("val", None) + if q_fake is not None and len(q_fake.shape) > 2: + if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)): + ad_logger.warning( + f"""Sanity check failed: q_fake should have shape [b, s, n, d], + s should be symbolic and n should be int, instead got shape {q_fake.shape}""" + ) + return + elif q_fake is not None: + ad_logger.warning( + f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}" + ) + return + + # Retrieve or register the lookup table for inv_freq_node -> cos_sin_flash + if inv_freq_node in cache: + cos_sin_flash = cache[inv_freq_node] + else: + # Compute the fused cosine/sine cache. + with graph.inserting_after(inv_freq_node): + real_part = graph.call_function(torch.ops.aten.real, args=(inv_freq_node,)) + imag_part = graph.call_function(torch.ops.aten.imag, args=(inv_freq_node,)) + with graph.inserting_after(real_part): + cos_sin_flash_3d = graph.call_function( + torch.ops.aten.cat, args=((real_part, imag_part), -1) + ) + with graph.inserting_after(q_node): + sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 0)) + sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, 1)) + with graph.inserting_after(_get_last_node([sym_batch, sym_seq])): + bs_seq = graph.call_function(operator.mul, args=(sym_batch, sym_seq)) + with graph.inserting_after(_get_last_node([bs_seq, cos_sin_flash_3d])): + fused_cos_sin_flat = graph.call_function( + torch.ops.aten.view, args=(cos_sin_flash_3d, (bs_seq, -1)) + ) + with graph.inserting_after(fused_cos_sin_flat): + cos_sin_flash = graph.call_function( + torch.ops.aten.to, args=(fused_cos_sin_flat, torch.float32) + ) + cache[inv_freq_node] = cos_sin_flash + + with graph.inserting_before(node): + position_ids = _get_position_ids( + graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=pos_cache + ) + flash_node = graph.call_function( + torch.ops.rope.flashinfer, + args=(q_node, k_node, position_ids, cos_sin_flash, False), + ) + + flash_node.meta["val"] = node.meta.get("val", None) + node.replace_all_uses_with(flash_node) + graph.erase_node(node) + + +def _match_input_interleave_pattern(node: Node) -> Optional[Dict[str, Node]]: + """ + Detect DeepSeek-style interleave on Q/K: + reshape(transpose(view(raw, [b,h,s,d//2,2]), 4, 3), [b,h,s,d]) + Returns: + {"interleaved": raw_node} if matched, else None. 
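For illustration (not part of this patch): the eager-mode interleave that this matcher looks for, as it appears in DeepSeek-style modeling code (even-indexed channels first, then odd-indexed, per head). Shapes are arbitrary.

import torch

b, h, s, d = 2, 4, 8, 64
x = torch.randn(b, h, s, d)
# view pairs adjacent channels, transpose swaps the pair axis, reshape flattens back:
# [x0, x1, x2, x3, ...] -> [x0, x2, ..., x1, x3, ...]
x_interleaved = x.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)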
+ """ + if not is_op(node, torch.ops.aten.reshape): + return None + transpose_node = node.args[0] + if not is_op(transpose_node, torch.ops.aten.transpose): + return None + view_node = transpose_node.args[0] + if not is_op(view_node, torch.ops.aten.view): + return None + raw_node = view_node.args[0] + if not isinstance(raw_node, Node): + return None + return {"interleaved": raw_node} + + +def _match_explicit_rope_subpattern(add_node: Node) -> Optional[Dict[str, Node]]: """ Given an aten.add.Tensor node that is expected to compute: output = (raw_input * unsqueeze(cos)) + (rotate_half(raw_input) * unsqueeze(sin)) @@ -195,7 +566,7 @@ def _match_rotary_subpattern_V1(add_node: Node) -> Optional[Dict[str, Node]]: } -def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]]: +def _match_complex_rope_subpattern(type_as_node: Node) -> Optional[Dict[str, Node]]: """ Given a type_as node, this function inspects the graph structure and returns a dictionary with: @@ -245,9 +616,10 @@ def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]] else: return None - # Verify that the unsqueeze is performed along dimension 2. - if not (len(unsqueeze_node.args) >= 2 and unsqueeze_node.args[1] == 2): + if not (len(unsqueeze_node.args) >= 2): return None + unsqueeze_dim = unsqueeze_node.args[1] + inv_freq_candidate = unsqueeze_node.args[0] # Match the view_as_complex branch. @@ -273,27 +645,27 @@ def _match_rotary_subpattern_V2(type_as_node: Node) -> Optional[Dict[str, Node]] "input": input_tensor, "inv_freq": inv_freq_candidate, "out": type_as_node, + "unsqueeze_dim": unsqueeze_dim, } -def _process_rope_v1( - graph: GraphModule, +def _process_explicit_rope( + graph: GraphModule.graph, q_match: Dict[str, Node], k_match: Dict[str, Node], start_boundary: Node, - rope_flash_cache: DefaultDict[Any, Optional[Node]], - rope_position_ids_cache: Dict[str, Node], ) -> None: """ - Process a region that matched the legacy RoPE pattern (v1). - Inserts the custom op (flashinfer) and replaces the original add nodes. - Precomputes positional IDs and the fused cosine-sine cache as explicit nodes, - and reuses those nodes when possible. + Replace matched Explicit RoPE subgraph with `rope::torch_apply_rope_with_explicit_cos_sin`. """ q_node = q_match["raw_input"] k_node = k_match["raw_input"] - cos_node = q_match["unsqueeze_cos"].args[0] - sin_node = q_match["unsqueeze_sin"].args[0] + cos_unsq = q_match["unsqueeze_cos"] + sin_unsq = q_match["unsqueeze_sin"] + cos_node = cos_unsq.args[0] + sin_node = sin_unsq.args[0] + unsq_dim = cos_unsq.args[1] + add_node = q_match["add_node"] # Sanity-check: ensure cos/sin nodes trace back to aten.cos/aten.sin. bfs( @@ -309,167 +681,197 @@ def _process_rope_v1( boundary=start_boundary, ) - # Infer input layout; default to [b, n, s, d] if inference fails. - q_fake = q_node.meta.get("val", None) - if q_fake is not None and len(q_fake.shape) > 2: - need_transpose = isinstance(q_fake.shape[1], int) - ad_logger.debug( - f"Inferred RoPE input layout: [{'[b, n, s, d]' if need_transpose else '[b, s, n, d]'}]" - ) - # Additional sanity check for the third dimension - if need_transpose: - if not isinstance(q_fake.shape[2], torch.SymInt): - ad_logger.warning( - "Sanity check failed: q_fake.shape[2] should be symbolic. Defaulting to [b, n, s, d]" - ) - need_transpose = True - else: - if not isinstance(q_fake.shape[1], torch.SymInt): - ad_logger.warning( - "Sanity check failed: q_fake.shape[2] should be symbolic. 
Defaulting to [b, n, s, d]" - ) - need_transpose = True - else: - ad_logger.warning("Unable to infer layout of q node. Defaulting to [b, n, s, d].") - need_transpose = True - - with graph.inserting_before(q_match["add_node"]): - if need_transpose: - q_for_op = graph.call_function(torch.ops.aten.transpose.int, args=(q_node, 1, 2)) - k_for_op = graph.call_function(torch.ops.aten.transpose.int, args=(k_node, 1, 2)) - q_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(q_for_op,)) - k_for_op_contig = graph.call_function(torch.ops.aten.contiguous, args=(k_for_op,)) - else: - q_for_op_contig, k_for_op_contig = q_node, k_node - - head_dim = cos_node.meta["val"].shape[-1] - half_head_dim = head_dim // 2 - - cache = rope_flash_cache - cache_key = (cos_node, sin_node) - if cache_key in cache: - fused_cos_sin = cache[cache_key] - else: - cos_prefix = graph.call_function( - torch.ops.aten.slice, args=(cos_node, -1, 0, half_head_dim) - ) - sin_prefix = graph.call_function( - torch.ops.aten.slice, args=(sin_node, -1, 0, half_head_dim) - ) - fused_cos_sin = graph.call_function( - torch.ops.aten.cat, args=((cos_prefix, sin_prefix), -1) - ) - fused_cos_sin = graph.call_function(operator.getitem, args=(fused_cos_sin, 0)) - fused_cos_sin = graph.call_function( - torch.ops.aten.to.dtype, args=(fused_cos_sin, torch.float32) - ) - cache[cache_key] = fused_cos_sin - - position_ids = _get_position_ids( - graph, - q_for_op_contig, - batch_dim=0, - seq_dim=1, - rope_position_ids_cache=rope_position_ids_cache, + with graph.inserting_before(add_node): + rope_node = graph.call_function( + torch.ops.rope.torch_apply_rope_with_explicit_cos_sin, + args=(q_node, k_node, cos_node, sin_node, unsq_dim), ) - flash_node = graph.call_function( - torch.ops.rope.flashinfer, - args=(q_for_op_contig, k_for_op_contig, position_ids, fused_cos_sin, True), + with graph.inserting_after(rope_node): + out_q = graph.call_function(operator.getitem, args=(rope_node, 0)) + out_k = graph.call_function(operator.getitem, args=(rope_node, 1)) + + out_q.meta["val"] = add_node.meta.get("val", None) + out_k.meta["val"] = k_match["add_node"].meta.get("val", None) + + q_match["add_node"].replace_all_uses_with(out_q) + k_match["add_node"].replace_all_uses_with(out_k) + + +def _process_complex_rope( + graph: GraphModule.graph, + q_match: Dict[str, Node], + k_match: Dict[str, Node], +) -> None: + """ + Replace matched Complex RoPE subgraph with `rope::torch_apply_rope_with_complex_freqs`. + """ + xq = q_match["input"] + xk = k_match["input"] + inv = q_match["inv_freq"] + usdim = q_match["unsqueeze_dim"] + out_node = q_match.get("out") + + if inv != k_match["inv_freq"]: + ad_logger.warning( + "Mismatch of freqs_cis (inv_freq) between branches. 
Fail to match complex rope pattern" + ) + return + + with graph.inserting_before(out_node): + rope_node = graph.call_function( + torch.ops.rope.torch_apply_rope_with_complex_freqs, + args=(xq, xk, inv, usdim), ) - with graph.inserting_after(flash_node): - raw_q = graph.call_function(operator.getitem, args=(flash_node, 0)) - raw_k = graph.call_function(operator.getitem, args=(flash_node, 1)) + with graph.inserting_after(rope_node): + out_q = graph.call_function(operator.getitem, args=(rope_node, 0)) + out_k = graph.call_function(operator.getitem, args=(rope_node, 1)) - if need_transpose: - with graph.inserting_after(raw_q): - new_q = graph.call_function(torch.ops.aten.transpose.int, args=(raw_q, 1, 2)) - with graph.inserting_after(raw_k): - new_k = graph.call_function(torch.ops.aten.transpose.int, args=(raw_k, 1, 2)) - else: - new_q, new_k = raw_q, raw_k + out_q.meta["val"] = out_node.meta.get("val", None) + out_k.meta["val"] = k_match["out"].meta.get("val", None) - new_q.meta["val"] = q_match["add_node"].meta.get("val", None) - new_k.meta["val"] = k_match["add_node"].meta.get("val", None) - - q_match["add_node"].replace_all_uses_with(new_q) - k_match["add_node"].replace_all_uses_with(new_k) + out_node.replace_all_uses_with(out_q) + k_match["out"].replace_all_uses_with(out_k) -def _process_rope_v2( +def _process_input_interleave_rope( graph: GraphModule, q_match: Dict[str, Node], k_match: Dict[str, Node], - rope_flash_cache: DefaultDict[Any, Optional[Node]], - rope_position_ids_cache: Dict[str, Node], ) -> None: """ - Process a region that matched the complex-multiplication RoPE pattern (v2). - Inserts the custom op (flashinfer) after extracting frequency information, - and replaces the original type_as nodes. - Precomputes positional IDs and the fused cosine-sine cache as explicit nodes, - and reuses those nodes when possible. + Replace a matched DS-style RoPE subgraph with a call to rope::torch_apply_rope_with_qk_interleaving. + Cache the one-time unsqueeze of cos/sin. """ - q_node = q_match["input"] - k_node = k_match["input"] - inv_freq_node = q_match["inv_freq"] + q_node = q_match["raw_input"] + k_node = k_match["raw_input"] + cos_node = q_match["unsqueeze_cos"].args[0] + sin_node = q_match["unsqueeze_sin"].args[0] + # A patch for the case when q_output appears before k_input in the graph + # Move q_output down right before its first user so that graph remains in + # topological order after inserting the apply rope custom op + q_match["add_node"] = _move_node_before_first_user(q_match["add_node"]) - if inv_freq_node != k_match["inv_freq"]: - raise RuntimeError("Mismatch of freqs_cis (inv_freq) between branches.") + # Infer unsqueeze_dim from layout + unsq_dim = 1 + fake = q_node.meta.get("val", None) + if fake is not None and len(fake.shape) == 4: + # if shape[1] symbolic, it's [B, S, N, D] => BSND -> head dim is 2 + if isinstance(fake.shape[1], torch.SymInt): + unsq_dim = 2 + else: + unsq_dim = 1 - # Sanity check that input layout is BSND (no transpose needed). 
- q_fake = q_node.meta.get("val", None) - if q_fake is not None and len(q_fake.shape) > 2: - if not (isinstance(q_fake.shape[1], torch.SymInt) and isinstance(q_fake.shape[2], int)): + with graph.inserting_after(_get_last_node([q_node, k_node, cos_node, sin_node])): + ds_node = graph.call_function( + torch.ops.rope.torch_apply_rope_with_qk_interleaving, + args=(q_node, k_node, cos_node, sin_node, unsq_dim), + ) + + with graph.inserting_after(ds_node): + q_out = graph.call_function(operator.getitem, args=(ds_node, 0)) + k_out = graph.call_function(operator.getitem, args=(ds_node, 1)) + + q_out.meta["val"] = q_match["add_node"].meta.get("val", None) + k_out.meta["val"] = k_match["add_node"].meta.get("val", None) + + q_match["add_node"].replace_all_uses_with(q_out) + k_match["add_node"].replace_all_uses_with(k_out) + + +def _move_node_before_first_user(node: Node) -> Node: + """ + Remove `node` from the graph and re-insert a clone of it immediately + before its earliest user. Returns the new node. + + If `node` has no users, or is already right before its first user, + this is a no-op and returns the original node. + """ + graph = node.graph + ordering = list(graph.nodes) + + users = list(node.users) + if not users: + return node + + # locate the earliest user in the current ordering + first_user = min(users, key=lambda u: ordering.index(u)) + if ordering.index(node) == ordering.index(first_user) - 1: + return node + + with graph.inserting_before(first_user): + new_node = graph.node_copy(node, lambda n: n) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + return new_node + + +def _get_last_node(nodes: Sequence[Node]) -> Node: + """ + Given a list of FX Nodes, + return the one that appears last in the graph's execution order. + """ + if not nodes: + raise ValueError("`nodes` must be a non-empty sequence of FX Node objects") + + graph = nodes[0].graph + ordering = list(graph.nodes) + + # Sanity check that all nodes are in same graph + valid = [n for n in nodes if n in ordering] + if not valid: + raise ValueError("None of the provided nodes belong to the same graph") + + last = max(valid, key=lambda n: ordering.index(n)) + return last + + +def _validate_rope_inputs(q_node: Node, k_node: Node) -> bool: + """ + Validates that: + - The last dimension (head_dim) of both q and k is a multiple of 64. + - The dtype of q and k is half precision (bfloat16 or float16). + - Layout should be [B,S,N,D] (dim 1 should be symbolic) + """ + for name, node in [("q", q_node), ("k", k_node)]: + fake_val = node.meta.get("val", None) + if fake_val is None: ad_logger.warning( - f"""Sanity check failed: q_fake should have shape [b, s, n, d], - s should be symbolic and n should be int, instead got shape {q_fake.shape}""" + f"Meta['val'] for {name} not available; skipping RoPE transformation." ) - else: - ad_logger.warning( - f"Sanity check failed: q_fake should be 3D or 4D, but got shape {q_fake.shape}" - ) + return False - # Retrieve or register the lookup table for inv_freq_node -> cos_sin_flash - cache = rope_flash_cache - if inv_freq_node in cache: - cos_sin_flash = cache[inv_freq_node] - else: - # Compute the fused cosine/sine cache. 
- with graph.inserting_after(inv_freq_node): - real_part = graph.call_function(torch.ops.aten.real, args=(inv_freq_node,)) - imag_part = graph.call_function(torch.ops.aten.imag, args=(inv_freq_node,)) - with graph.inserting_after(real_part): - cos_sin_flash_3d = graph.call_function( - torch.ops.aten.cat, args=((real_part, imag_part), -1) + # Check dtype + if fake_val.dtype not in (torch.float16, torch.bfloat16): + ad_logger.warning( + f"""{name} tensor is {fake_val.dtype}, + expected half precision (float16 or bfloat16). Skipping RoPE transformation.""" ) - with graph.inserting_after(cos_sin_flash_3d): - cos_sin_flash = graph.call_function(operator.getitem, args=(cos_sin_flash_3d, 0)) - with graph.inserting_after(cos_sin_flash): - cos_sin_flash = graph.call_function( - torch.ops.aten.to.dtype, args=(cos_sin_flash, torch.float32) + return False + + # Check head_dim + if len(fake_val.shape) < 1: + ad_logger.warning(f"{name} tensor has invalid shape {fake_val.shape}.") + return False + head_dim = fake_val.shape[-1] + if isinstance(head_dim, int) and head_dim % 64 != 0: + ad_logger.warning( + f"{name} head_dim = {head_dim} is not a multiple of 64. Skipping RoPE transformation." ) - cache[inv_freq_node] = cos_sin_flash + return False - with graph.inserting_before(q_match["out"]): - position_ids = _get_position_ids( - graph, q_node, batch_dim=0, seq_dim=1, rope_position_ids_cache=rope_position_ids_cache - ) - flash_node = graph.call_function( - torch.ops.rope.flashinfer, - args=(q_node, k_node, position_ids, cos_sin_flash, False), - ) + # Check shape + if not isinstance(fake_val.shape[1], torch.SymInt): + ad_logger.warning( + f"{name} has shape {fake_val.shape} that is not supported. Only support [B, S, N, D] layout.\ + Skipping RoPE transformation." + ) + return False - with graph.inserting_after(flash_node): - raw_q = graph.call_function(operator.getitem, args=(flash_node, 0)) - raw_k = graph.call_function(operator.getitem, args=(flash_node, 1)) - - raw_q.meta["val"] = q_match["out"].meta.get("val", None) - raw_k.meta["val"] = k_match["out"].meta.get("val", None) - - q_match["out"].replace_all_uses_with(raw_q) - k_match["out"].replace_all_uses_with(raw_k) + return True def _get_position_ids( @@ -491,22 +893,17 @@ def _get_position_ids( sym_batch = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, batch_dim)) sym_seq = graph.call_function(torch.ops.aten.sym_size.int, args=(q_node, seq_dim)) + bs_seq = graph.call_function(operator.mul, args=(sym_batch, sym_seq)) # Retrieve device information, ensuring it is a torch.device. device = q_node.meta.get("device", "cpu") if isinstance(device, str): device = torch.device(device) - # Build positions: arange(sym_seq) -> view -> expand -> flatten. 
- positions_node = graph.call_function( + position_ids = graph.call_function( torch.ops.aten.arange, - args=(sym_seq,), + args=(bs_seq,), kwargs={"dtype": torch.float32, "device": device, "pin_memory": False}, ) - positions_node = graph.call_function(torch.ops.aten.view, args=(positions_node, (1, -1))) - positions_node = graph.call_function( - torch.ops.aten.expand, args=(positions_node, (sym_batch, -1)) - ) - position_ids = graph.call_function(torch.ops.aten.flatten, args=(positions_node,)) rope_position_ids_cache["position_ids"] = position_ids return position_ids diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py index a9bdd7402f..fb91b828d7 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py @@ -25,12 +25,14 @@ from .library import ( insert_cached_attention, match_attention_layout, match_causal_attn_mask, + match_complex_rope, match_eager_attention, + match_explicit_rope, match_grouped_attention, match_moe_pattern, match_repeat_kv, - match_rope_v1, - match_rope_v2, + match_rope_layout, + optimize_rope, quantize, resize_kv_cache, ) @@ -122,10 +124,10 @@ class InferenceOptimizer: egm = match_attention_layout(egm, self.attention_op) # Match rope - # TODO (lucaslie): let's move this to perf optimization once TP sharding is improved - # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528 - egm = match_rope_v1(egm) - egm = match_rope_v2(egm) + egm = match_explicit_rope(egm) + egm = match_complex_rope(egm) + # Match RoPE layout expected by our backend + egm = match_rope_layout(egm, self.attention_op.get_attention_layout()) ############################################################################################ # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION @@ -134,6 +136,10 @@ class InferenceOptimizer: # eliminate redundant transpose operations egm = eliminate_redundant_transposes(egm) + # TODO (lucaslie): let's move this to perf optimization once TP sharding is improved + # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528 + egm = optimize_rope(egm) + # run TP sharding across ranks egm = column_row_shard(egm, local_rank, world_size) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index c884cf84c5..9d3f630fec 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -1,5 +1,6 @@ """Common utils for torch fx graph transformation.""" +import operator from dataclasses import dataclass from typing import Callable, Iterable, List, Optional, Tuple, Union @@ -320,3 +321,22 @@ def bfs( visited.add(next_node) queue.append(next_node) raise RuntimeError(f"Could not find node with target condition {target}.") + + +def extract_output_tuple(node: Node, count: int = 2): + """ + Extract up to `count` outputs from a tuple-producing node. + Returns a list of length `count`, with None if an output isn't found. 
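For illustration (not part of this patch): a typical defensive use of extract_output_tuple on a two-output node (for example one of the rope::* calls), matching how the rope transformations above bail out when either getitem user is missing. `rope_node` is a hypothetical torch.fx node.

from tensorrt_llm._torch.auto_deploy.utils.node_utils import extract_output_tuple

def get_rope_outputs(rope_node):
    # Returns the two operator.getitem users of `rope_node`, or None if either is absent.
    q_out, k_out = extract_output_tuple(rope_node, 2)
    if q_out is None or k_out is None:
        return None
    return q_out, k_out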
+ """ + results = [] + for idx in range(count): + user_node = next( + ( + u + for u in node.users + if u.op == "call_function" and u.target == operator.getitem and u.args[1] == idx + ), + None, + ) + results.append(user_node) + return results diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py index 00cd095a3a..7422ddafac 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py @@ -51,7 +51,7 @@ def run_test( num_params_gm = count_parameters(gm) assert num_params_model == num_params_gm - assert all_close(y_model, y_gm, atol=atol, rtol=rtol) + torch.testing.assert_close(y_model, y_gm, atol=atol, rtol=rtol) # graph transformation + check gm_transformed = transform(gm, *args) @@ -71,9 +71,7 @@ def run_test( if strict_loading: # check if output equals without loading state dict - assert all_close(y_model, y_transformed, atol=atol, rtol=rtol), ( - f"{y_model=}, {y_transformed=}" - ) + torch.testing.assert_close(y_model, y_transformed, atol=atol, rtol=rtol) if test_load_hook: # check if loading hook works from original state dict @@ -83,9 +81,7 @@ def run_test( gm_transformed.load_state_dict(model.state_dict(), strict=True if strict_loading else False) y_loaded_from_original = gm_transformed(x) - assert all_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol), ( - f"{y_model=}, {y_loaded_from_original=}" - ) + torch.testing.assert_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol) # check if loading hook works from state_dict of a transformed model state_dict_sharded = copy.deepcopy(gm_transformed.state_dict()) @@ -95,9 +91,7 @@ def run_test( gm_transformed.load_state_dict(state_dict_sharded, strict=True if strict_loading else False) y_loaded_from_transformed = gm_transformed(x) - assert all_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol), ( - f"{y_model=}, {y_loaded_from_transformed=}" - ) + torch.testing.assert_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol) # check if we can still export the model as expected torch_export(gm_transformed, args=(x,)) diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index 3ce446c511..ab2713805e 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -222,3 +222,55 @@ def _hf_model_dir_or_hub_id( return hf_model_dir else: return hf_hub_id + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb_explicit( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 +): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_emb_complex( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2) and complex dtype. + unsqueeze_dim: int = 2, +): + # Reshape the inputs to pair the last dimension. 
+ xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + # Multiply with frequencies. Note that freqs_cis is expected to broadcast with an extra head dim. + freqs = freqs_cis.unsqueeze(unsqueeze_dim) + xq_out = torch.view_as_real(xq_complex * freqs).flatten(3) + xk_out = torch.view_as_real(xk_complex * freqs).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L339 +def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """ + Apply rotary positional embeddings by interleaving Q/K , + indexing cos/sin tables with position_ids, and returning rotated q, k. + cos: [seq_len, head_dim] + sin: [seq_len, head_dim] + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_hf_flashinfer_rope.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_hf_flashinfer_rope.py deleted file mode 100644 index 495ef6a9ff..0000000000 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_hf_flashinfer_rope.py +++ /dev/null @@ -1,145 +0,0 @@ -from typing import Tuple - -import flashinfer -import pytest -import torch - -import tensorrt_llm._torch.auto_deploy # noqa: F401 - -torch.manual_seed(0) - - -def rotate_half(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 -): - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -@pytest.mark.parametrize("head_dim", [64, 256]) # head_dim must be a multiple of 64 -@pytest.mark.parametrize( - "dtype,atol,rtol", - [ - (torch.bfloat16, 1e-4, 1e-4), - (torch.float16, 5e-4, 5e-4), - ], - ids=["bfloat16", "float16"], # q/k must be in half precision -) -def test_flashinfer_and_custom_rope_ops(dtype, atol, rtol, head_dim): - device = "cuda" - batch = 2 - seq_len = 4 - n_head = 3 - - # Prepare rotary embedding values. - inv_freq = 1.0 / ( - 10000 - ** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2)) - ) - positions_range = torch.arange(seq_len, dtype=torch.float32, device=device) - angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # [seq_len, head_dim//2] - cos_vals = torch.cos(angles) # [seq_len, head_dim//2] - sin_vals = torch.sin(angles) # [seq_len, head_dim//2] - - # For direct FlashInfer call: non-interleaved cache [seq_len, head_dim] (concatenated). - cos_sin_cache = torch.cat([cos_vals, sin_vals], dim=1) - # For HF and the custom op: duplicated layout [seq_len, head_dim]. 
- cos_new = torch.cat([cos_vals, cos_vals], dim=-1) - sin_new = torch.cat([sin_vals, sin_vals], dim=-1) - - query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device) - key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device) - - # Direct FlashInfer kernel call. - query_flat = query.view(batch * seq_len, n_head * head_dim) - key_flat = key.view(batch * seq_len, n_head * head_dim) - positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)]) - q_flash, k_flash = flashinfer.rope.apply_rope_with_cos_sin_cache( - positions, query_flat, key_flat, head_dim, cos_sin_cache, is_neox=True - ) - q_flash = q_flash.view(batch, seq_len, n_head, head_dim) - k_flash = k_flash.view(batch, seq_len, n_head, head_dim) - - # HF implementation using apply_rotary_pos_emb. - # HF expects [batch, n_head, seq_len, head_dim] for unsqueeze_dim=1 - q_for_hf = query.transpose(1, 2).clone() - k_for_hf = key.transpose(1, 2).clone() - cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim] - sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim] - q_hf, k_hf = apply_rotary_pos_emb(q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1) - - # Convert outputs to [batch, seq_len, n_head, head_dim] - q_hf = q_hf.transpose(1, 2).to(dtype) - k_hf = k_hf.transpose(1, 2).to(dtype) - - # Custom op call - custom_q, custom_k = torch.ops.rope.flashinfer(query, key, positions, cos_sin_cache, True) - - torch.testing.assert_close(q_hf, q_flash, rtol=rtol, atol=atol) - torch.testing.assert_close(k_hf, k_flash, rtol=rtol, atol=atol) - torch.testing.assert_close(q_hf, custom_q, rtol=rtol, atol=atol) - torch.testing.assert_close(k_hf, custom_k, rtol=rtol, atol=atol) - - -# Version 2: complex multiplication approach -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2) -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -@pytest.mark.parametrize("head_dim", [64, 256]) # Must be a multiple of 64 -@pytest.mark.parametrize( - "dtype,atol,rtol", - [ - (torch.bfloat16, 1e-5, 1e-5), - (torch.float16, 5e-4, 5e-4), - ], - ids=["bfloat16", "float16"], # q/k must be in half precision -) -def test_flashinfer_complex_rotary(dtype, atol, rtol, head_dim): - device = "cuda" - batch = 2 - seq_len = 4 - n_head = 3 - - inv_freq = 1.0 / ( - 10000 - ** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2)) - ) - positions_range = torch.arange(seq_len, dtype=torch.float32, device=device) - angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # shape: (seq_len, head_dim//2) - freqs_cis = torch.polar(torch.ones((seq_len, head_dim // 2), device=device), angles) - freqs_cis = freqs_cis.unsqueeze(0).expand(batch, -1, -1) # shape: (B, seq, head_dim//2) - - query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device) - key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device) - - out_q_v2, out_k_v2 = apply_rotary_emb(query, key, freqs_cis) - - cos_from_freqs = torch.real(freqs_cis) # (B, seq, head_dim//2) - sin_from_freqs = torch.imag(freqs_cis) # (B, 
seq, head_dim//2) - cos_sin_cache = torch.cat([cos_from_freqs, sin_from_freqs], dim=-1)[0] # (seq, head_dim)) - - # q/k of llama4 rope is interleaved - positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)]) - custom_q, custom_k = torch.ops.rope.flashinfer(query, key, positions, cos_sin_cache, False) - - torch.testing.assert_close(out_q_v2, custom_q, rtol=rtol, atol=atol) - torch.testing.assert_close(out_k_v2, custom_k, rtol=rtol, atol=atol) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_rope_op_variants.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_rope_op_variants.py new file mode 100644 index 0000000000..ee7bca3f6b --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_rope_op_variants.py @@ -0,0 +1,294 @@ +import flashinfer +import pytest +import torch +from _model_test_utils import ( + apply_rotary_pos_emb_complex, + apply_rotary_pos_emb_ds, + apply_rotary_pos_emb_explicit, +) + +import tensorrt_llm._torch.auto_deploy # noqa: F401 + +torch.manual_seed(0) + + +@pytest.mark.parametrize("head_dim", [64, 256]) # head_dim must be a multiple of 64 +@pytest.mark.parametrize( + "dtype,atol,rtol", + [ + (torch.bfloat16, 1e-4, 1e-4), + (torch.float16, 5e-4, 5e-4), + ], + ids=["bfloat16", "float16"], # q/k must be in half precision +) +def test_flashinfer_custom_op_and_hf_impl(dtype, atol, rtol, head_dim): + """ + Verify FlashInfer's Neox RoPE kernel against HF's apply_rotary_pos_emb: + - Q/K: [B, S, N, D] non-interleaved half-precision. + - cos_sin_cache: [S, D] = [cos||sin] concatenated. + - HF path: Q/K → [B, N, S, D], cos_new/sin_new: [S, D] duplicated, then broadcast to [B, S, D]. + """ + device = "cuda" + batch = 2 + seq_len = 4 + n_head = 3 + + # Prepare rotary embedding values. + inv_freq = 1.0 / ( + 10000 + ** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2)) + ) + positions_range = torch.arange(seq_len, dtype=torch.float32, device=device) + angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # [seq_len, head_dim//2] + cos_vals = torch.cos(angles) # [seq_len, head_dim//2] + sin_vals = torch.sin(angles) # [seq_len, head_dim//2] + + # For direct FlashInfer call: non-interleaved cache [seq_len, head_dim] (concatenated). + cos_sin_cache = torch.cat([cos_vals, sin_vals], dim=1) + cos_sin_cache_expand = ( + cos_sin_cache.unsqueeze(0).expand(batch, -1, -1).contiguous().view(batch * seq_len, -1) + ) # [batch * seq_len, head_dim] + # For HF and the custom op: duplicated layout [seq_len, head_dim]. + cos_new = torch.cat([cos_vals, cos_vals], dim=-1) + sin_new = torch.cat([sin_vals, sin_vals], dim=-1) + + query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device) + key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device) + + # Direct FlashInfer kernel call. + query_flat = query.view(batch * seq_len, n_head * head_dim) + key_flat = key.view(batch * seq_len, n_head * head_dim) + positions = torch.cat([torch.arange(seq_len, device=device) for _ in range(batch)]) + q_flash, k_flash = flashinfer.rope.apply_rope_with_cos_sin_cache( + positions, query_flat, key_flat, head_dim, cos_sin_cache, is_neox=True + ) + q_flash = q_flash.view(batch, seq_len, n_head, head_dim) + k_flash = k_flash.view(batch, seq_len, n_head, head_dim) + + # HF implementation using apply_rotary_pos_emb. 
+    # HF expects [batch, n_head, seq_len, head_dim] for unsqueeze_dim=1
+    q_for_hf = query.transpose(1, 2).clone()
+    k_for_hf = key.transpose(1, 2).clone()
+    cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1)  # [batch, seq_len, head_dim]
+    sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1)  # [batch, seq_len, head_dim]
+    q_hf, k_hf = apply_rotary_pos_emb_explicit(
+        q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1
+    )
+
+    # Convert outputs to [batch, seq_len, n_head, head_dim]
+    q_hf = q_hf.transpose(1, 2).to(dtype)
+    k_hf = k_hf.transpose(1, 2).to(dtype)
+
+    # Custom op call
+    positions_flat = torch.arange(batch * seq_len, device=device)
+    custom_q, custom_k = torch.ops.rope.flashinfer(
+        query, key, positions_flat, cos_sin_cache_expand, True
+    )
+
+    torch.testing.assert_close(q_hf, q_flash, rtol=rtol, atol=atol)
+    torch.testing.assert_close(k_hf, k_flash, rtol=rtol, atol=atol)
+    torch.testing.assert_close(q_hf, custom_q, rtol=rtol, atol=atol)
+    torch.testing.assert_close(k_hf, custom_k, rtol=rtol, atol=atol)
+
+
+@pytest.mark.parametrize("head_dim", [64, 256])  # Must be a multiple of 64
+@pytest.mark.parametrize(
+    "dtype,atol,rtol",
+    [
+        (torch.bfloat16, 1e-5, 1e-5),
+        (torch.float16, 5e-4, 5e-4),
+    ],
+    ids=["bfloat16", "float16"],  # q/k must be in half precision
+)
+def test_flashinfer_custom_op_and_complex_impl(dtype, atol, rtol, head_dim):
+    """
+    Check that FlashInfer's RoPE output matches the complex-multiplication approach:
+    - Q/K: [B, S, N, D] half-precision, consumed as interleaved (real, imag) channel pairs.
+    - freqs_cis: [B, S, D/2] complex polar values.
+    - FlashInfer uses cos_sin_cache: [S, D], concatenated from the real/imag parts of freqs_cis.
+    """
+    device = "cuda"
+    batch = 2
+    seq_len = 4
+    n_head = 3
+
+    inv_freq = 1.0 / (
+        10000
+        ** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2))
+    )
+    positions_range = torch.arange(seq_len, dtype=torch.float32, device=device)
+    angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0)  # shape: (seq_len, head_dim//2)
+    freqs_cis = torch.polar(torch.ones((seq_len, head_dim // 2), device=device), angles)
+    freqs_cis = freqs_cis.unsqueeze(0).expand(batch, -1, -1)  # shape: (B, seq, head_dim//2)
+
+    query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
+    key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device)
+
+    out_q_v2, out_k_v2 = apply_rotary_pos_emb_complex(query, key, freqs_cis)
+
+    cos_from_freqs = torch.real(freqs_cis)  # (B, seq, head_dim//2)
+    sin_from_freqs = torch.imag(freqs_cis)  # (B, seq, head_dim//2)
+    cos_sin_cache = torch.cat([cos_from_freqs, sin_from_freqs], dim=-1)[0]  # (seq, head_dim)
+    cos_sin_cache_expand = (
+        cos_sin_cache.unsqueeze(0).expand(batch, -1, -1).contiguous().view(batch * seq_len, -1)
+    )  # [batch * seq_len, head_dim]
+
+    # q/k of llama4 rope is interleaved, hence is_neox=False
+    positions_flat = torch.arange(batch * seq_len, device=device)
+    custom_q, custom_k = torch.ops.rope.flashinfer(
+        query, key, positions_flat, cos_sin_cache_expand, False
+    )
+
+    torch.testing.assert_close(out_q_v2, custom_q, rtol=rtol, atol=atol)
+    torch.testing.assert_close(out_k_v2, custom_k, rtol=rtol, atol=atol)
+
+
+# Copy of TritonWithFlattenedInputs._precompute_freqs_cis
+def precompute_freqs_cis_interleaved(
+    seq_len: int, head_dim: int, dtype: torch.dtype, device: torch.device
+) -> torch.Tensor:
+    """
+    Precompute interleaved cosine and sine frequency cache for rotary position embeddings (RoPE).
+ + Returns a tensor of shape [seq_len, head_dim//2, 2], where the last dimension + alternates [cos, sin] values for each rotary frequency. + cache[s, i, 0] == cos(position=s · inv_freq[i]) + cache[s, i, 1] == sin(position=s · inv_freq[i]). + """ + inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)) + t = torch.arange(seq_len, device=device) + angles = t.unsqueeze(1) * inv_freq.unsqueeze(0) + freqs_cis = torch.polar(torch.ones_like(angles), angles) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype) + + +@pytest.mark.parametrize("layout", ["bsnd", "bnsd"]) +@pytest.mark.parametrize("head_dim", [64, 256]) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.bfloat16, 1e-4, 1e-4), + (torch.float16, 5e-4, 5e-4), + ], + ids=["bfloat16", "float16"], +) +def test_triton_custom_op_and_hf_impl(layout, head_dim, dtype, atol, rtol): + """ + Validate custom Triton apply_rope_with_input_pos against HF's apply_rotary_pos_emb: + - Q/K: layout 'bsnd'→[B,S,N,D] or 'bnsd'→[B,N,S,D], non-interleaved half-precision. + - cosin_cache: [S, D/2, 2] interleaved [cos,sin]. + - HF path: cos_full/sin_full: [S, D] then expanded to [B, S, D]. + """ + device = "cuda" + batch, seq_len, n_head = 2, 4, 3 + + # build cache and per-batch zero positions + cosin_cache = precompute_freqs_cis_interleaved(seq_len, head_dim, dtype, device) # [S, D/2, 2] + positions = torch.zeros(batch, dtype=torch.int32, device=device) + + if layout == "bsnd": + q = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device) + k = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device) + unsq = 2 + else: # "bnsd" + q = torch.randn(batch, n_head, seq_len, head_dim, dtype=dtype, device=device) + k = torch.randn(batch, n_head, seq_len, head_dim, dtype=dtype, device=device) + unsq = 1 + + # build HF float32 cos/sin full tensors + cos_f32 = cosin_cache[..., 0].to(torch.float32) # [S, H/2] + sin_f32 = cosin_cache[..., 1].to(torch.float32) # [S, H/2] + cos_full = torch.cat([cos_f32, cos_f32], dim=1) # [S, H] + sin_full = torch.cat([sin_f32, sin_f32], dim=1) # [S, H] + cos_exp = cos_full.unsqueeze(0).expand(batch, -1, -1) # [B, S, H] + sin_exp = sin_full.unsqueeze(0).expand(batch, -1, -1) # [B, S, H] + + # HF reference in float32, then cast back + q_f32, k_f32 = apply_rotary_pos_emb_explicit( + q.to(torch.float32), k.to(torch.float32), cos_exp, sin_exp, unsqueeze_dim=unsq + ) + q_hf = q_f32.to(dtype) + k_hf = k_f32.to(dtype) + + q_out = torch.ops.rope.apply_rope_with_input_pos(q, cosin_cache, positions, layout) + k_out = torch.ops.rope.apply_rope_with_input_pos(k, cosin_cache, positions, layout) + + torch.testing.assert_close(q_hf, q_out, atol=atol, rtol=rtol) + torch.testing.assert_close(k_hf, k_out, atol=atol, rtol=rtol) + + +def inverse_interleave_permute_for_rotary(x: torch.Tensor) -> torch.Tensor: + b, h, s, d = x.shape + x = x.view(b, h, s, 2, d // 2) + x = x.transpose(4, 3) + return x.reshape(b, h, s, d) + + +@pytest.mark.parametrize("head_dim", [64, 256]) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.bfloat16, 1e-4, 1e-4), + (torch.float16, 5e-4, 5e-4), + ], + ids=["bfloat16", "float16"], +) +def test_ds_impl_and_hf_impl(dtype, head_dim, atol, rtol): + """ + Ensure Deepseek's interleaved-Q/K RoPE matches HF apply_rotary_pos_emb: + - DS Q/K: [B, N, S, D] channel-interleaved in last dim. + - cos_new/sin_new: [S, D] duplicated real values. 
+ - HF path: Q/K → [B,N,S,D], cos_expand/sin_expand: [B,S,D], unsqueezed at dim=1. + """ + device = "cuda" + batch = 2 + seq_len = 4 + n_head = 3 + + position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch, seq_len) + inv_freq = 1.0 / ( + 10000 + ** (torch.arange(0, head_dim // 2, dtype=torch.float32, device=device) / (head_dim // 2)) + ) + positions_range = torch.arange(seq_len, dtype=torch.float32, device=device) + angles = positions_range.unsqueeze(1) * inv_freq.unsqueeze(0) # [seq_len, head_dim//2] + cos_vals = torch.cos(angles) # [seq_len, head_dim//2] + sin_vals = torch.sin(angles) # [seq_len, head_dim//2] + # duplicate to shape [seq_len, head_dim] + cos_new = torch.cat([cos_vals, cos_vals], dim=-1) + sin_new = torch.cat([sin_vals, sin_vals], dim=-1) + + query = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device) + key = torch.randn(batch, seq_len, n_head, head_dim, dtype=dtype, device=device) + + # HF torch expects inputs of shape [B, N, S, D] + q_for_hf = query.transpose(1, 2).clone() + k_for_hf = key.transpose(1, 2).clone() + cos_expand = cos_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim] + sin_expand = sin_new.unsqueeze(0).expand(batch, -1, -1) # [batch, seq_len, head_dim] + q_rotated_hf, k_rotated_hf = apply_rotary_pos_emb_explicit( + q_for_hf, k_for_hf, cos_expand, sin_expand, unsqueeze_dim=1 + ) + q_rotated_hf = q_rotated_hf.transpose(1, 2).to(torch.float32) + k_rotated_hf = k_rotated_hf.transpose(1, 2).to(torch.float32) + + q_for_hf2 = inverse_interleave_permute_for_rotary(q_for_hf.clone()) + k_for_hf2 = inverse_interleave_permute_for_rotary(k_for_hf.clone()) + + # adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L134 + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos_ds = emb.cos() + sin_ds = emb.sin() + + torch.testing.assert_close(cos_ds, cos_new, rtol=rtol, atol=atol) + torch.testing.assert_close(sin_ds, sin_new, rtol=rtol, atol=atol) + + q_rotated_hf2, k_rotated_hf2 = apply_rotary_pos_emb_ds( + q_for_hf2, k_for_hf2, cos_new, sin_new, position_ids, unsqueeze_dim=1 + ) + + torch.testing.assert_close(q_rotated_hf2.transpose(1, 2), q_rotated_hf, rtol=rtol, atol=atol) + torch.testing.assert_close(k_rotated_hf2.transpose(1, 2), k_rotated_hf, rtol=rtol, atol=atol) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_matcher.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_matcher.py deleted file mode 100644 index 4dbcecffa3..0000000000 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_matcher.py +++ /dev/null @@ -1,211 +0,0 @@ -import pytest -import torch -from _graph_test_helpers import run_test -from torch.export import Dim - -from tensorrt_llm._torch.auto_deploy.transformations.library.rope import ( - match_rope_v1, - match_rope_v2, -) -from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op - -torch.manual_seed(0) - - -def rotate_half(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb( - q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1 -): - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + 
(rotate_half(k) * sin) - return q_embed, k_embed - - -def _precompute_freqs_cis(seq_len: int, head_dim: int, rope_theta: float): - dtype = torch.float32 - inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) - positions = torch.arange(seq_len, dtype=torch.float32) - freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype) - sin = emb.sin().to(dtype) - return cos, sin - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, # Expected shape: (B, seq, head_dim//2) and complex dtype. -): - # Reshape the inputs to pair the last dimension. - xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - # Multiply with frequencies. Note that freqs_cis is expected to broadcast with an extra head dim. - xq_out = torch.view_as_real(xq_complex * freqs_cis[:, :, None, :]).flatten(3) - xk_out = torch.view_as_real(xk_complex * freqs_cis[:, :, None, :]).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def _precompute_freqs_cis_v2(seq_len: int, head_dim: int, rope_theta: float): - """ - Compute the frequency tensor for the complex multiplication RoPE variant. - Returns a complex tensor of shape (seq_len, head_dim//2). - """ - inv_freq = 1.0 / ( - rope_theta ** (torch.arange(0, head_dim // 2, dtype=torch.float32) / (head_dim // 2)) - ) - positions = torch.arange(seq_len, dtype=torch.float32) - angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # (seq_len, head_dim//2) - # Create a complex tensor from magnitude=1 and the computed angles. - freqs_cis = torch.polar(torch.ones_like(angles), angles) - return freqs_cis - - -class RotaryModel(torch.nn.Module): - def __init__( - self, - hidden_size: int, - max_seq_len: int, - num_heads: int, - num_kv_heads: int, - layout: str = "BNSD", - ): - super().__init__() - self.hidden_size = hidden_size - self.max_seq_len = max_seq_len - self.layout = layout - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = hidden_size // num_heads - self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim) - self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - q = self.linear_q(x) - k = self.linear_k(x) - - batch, seq, _ = q.shape - q = q.view(batch, seq, self.num_heads, self.head_dim) - k = k.view(batch, seq, self.num_kv_heads, self.head_dim) - - if self.layout == "BNSD": - q = q.permute(0, 2, 1, 3).contiguous() # [B, N, S, D] - k = k.permute(0, 2, 1, 3).contiguous() - unsqueeze_dim = 1 - else: # BSND - unsqueeze_dim = 2 - - cos, sin = _precompute_freqs_cis(seq, self.head_dim, rope_theta=10000) - cos = cos.to(q.device).unsqueeze(0).expand(batch, -1, -1) - sin = sin.to(q.device).unsqueeze(0).expand(batch, -1, -1) - - q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) - if self.layout == "BNSD": - # [B, N, S, D] -> [B, S, N*D] - q_embed = q_embed.permute(0, 2, 1, 3).reshape(batch, seq, -1) - k_embed = k_embed.permute(0, 2, 1, 3).reshape(batch, seq, -1) - else: # BSND - q_embed = q_embed.reshape(batch, seq, -1) - k_embed = k_embed.reshape(batch, seq, -1) - - output = torch.cat([q_embed, k_embed], dim=-1) - return output.to(torch.float16) - - def get_dynamic_shapes(self): - return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)} - - -class 
RotaryModelV2(torch.nn.Module): - def __init__(self, hidden_size: int, max_seq_len: int, num_heads: int, num_kv_heads: int): - super().__init__() - self.hidden_size = hidden_size - self.max_seq_len = max_seq_len - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads - self.head_dim = hidden_size // num_heads - - self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim) - self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - batch, seq, _ = x.shape - - q = self.linear_q(x).view(batch, seq, self.num_heads, self.head_dim) - k = self.linear_k(x).view(batch, seq, self.num_kv_heads, self.head_dim) - - freqs_cis = _precompute_freqs_cis_v2(seq, self.head_dim, rope_theta=10000) - freqs_cis = freqs_cis.to(x.device).unsqueeze(0).expand(batch, -1, -1) - - q_embed, k_embed = apply_rotary_emb(q, k, freqs_cis) - - q_embed = q_embed.reshape(batch, seq, -1) - k_embed = k_embed.reshape(batch, seq, -1) - - output = torch.cat([q_embed, k_embed], dim=-1) - return output.to(torch.float16) - - def get_dynamic_shapes(self): - return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)} - - -@pytest.mark.parametrize("layout", ["BNSD", "BSND"]) -@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4)]) -@torch.inference_mode() -def test_match_rope(layout, num_heads, num_kv_heads): - batch_size, seq_len = 8, 16 - hidden_size = 512 - max_position_embeddings = seq_len - - model = RotaryModel( - hidden_size, max_position_embeddings, num_heads, num_kv_heads, layout=layout - ).to("cuda", dtype=torch.float16) - x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16) - dynamic_shapes = model.get_dynamic_shapes() - - _ = run_test( - model, - x, - match_rope_v1, - lambda gm: any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes), - lambda num_p_og: num_p_og, - atol=1e-3, - rtol=1e-3, - test_load_hook=True, - strict_loading=True, - dynamic_shapes=dynamic_shapes, - ) - - -@pytest.mark.parametrize("num_heads, num_kv_heads", [(8, 8), (8, 4)]) -@torch.inference_mode() -def test_match_rope_v2(num_heads, num_kv_heads): - batch_size, seq_len = 8, 16 - hidden_size = 512 - max_position_embeddings = seq_len - - model = RotaryModelV2(hidden_size, max_position_embeddings, num_heads, num_kv_heads).to( - "cuda", dtype=torch.float16 - ) - - x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16) - dynamic_shapes = model.get_dynamic_shapes() - - _ = run_test( - model, - x, - match_rope_v2, - lambda gm: any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes), - lambda num_p_og: num_p_og, - atol=1e-3, - rtol=1e-3, - test_load_hook=True, - strict_loading=True, - dynamic_shapes=dynamic_shapes, - ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py new file mode 100644 index 0000000000..d43ba0cd4b --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_rope_transformation.py @@ -0,0 +1,429 @@ +import pytest +import torch +from _graph_test_helpers import run_test +from _model_test_utils import ( + apply_rotary_pos_emb_complex, + apply_rotary_pos_emb_ds, + apply_rotary_pos_emb_explicit, +) +from torch.export import Dim + +from tensorrt_llm._torch.auto_deploy.transformations.library.rope import ( + match_complex_rope, + match_explicit_rope, + 
match_rope_layout, + optimize_rope, +) +from tensorrt_llm._torch.auto_deploy.utils.node_utils import extract_output_tuple, is_op + +torch.manual_seed(0) + + +def _precompute_freqs_cis_explicit(seq_len: int, head_dim: int, rope_theta: float): + dtype = torch.float32 + inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) + positions = torch.arange(seq_len, dtype=torch.float32) + freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype) + sin = emb.sin().to(dtype) + return cos, sin + + +def _precompute_freqs_cis_complex(seq_len: int, head_dim: int, rope_theta: float): + """ + Compute the frequency tensor for the complex multiplication RoPE variant. + Returns a complex tensor of shape (seq_len, head_dim//2). + """ + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, head_dim // 2, dtype=torch.float32) / (head_dim // 2)) + ) + positions = torch.arange(seq_len, dtype=torch.float32) + angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0) # (seq_len, head_dim//2) + # Create a complex tensor from magnitude=1 and the computed angles. + freqs_cis = torch.polar(torch.ones_like(angles), angles) + return freqs_cis + + +class RoPEModel(torch.nn.Module): + def __init__( + self, + hidden_size: int, + max_seq_len: int, + num_heads: int, + num_kv_heads: int, + variant: str = "explicit", # "explicit" or "complex" + mode: str = "match", # "match" or "optimize" + layout: str = "BNSD", # "BNSD" or "BSND" + ): + super().__init__() + self.hidden_size = hidden_size + self.max_seq_len = max_seq_len + self.variant = variant + self.mode = mode + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = hidden_size // num_heads + self.layout = layout + + self.linear_q = torch.nn.Linear(hidden_size, num_heads * self.head_dim) + self.linear_k = torch.nn.Linear(hidden_size, num_kv_heads * self.head_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, s, _ = x.shape + q = self.linear_q(x) + k = self.linear_k(x) + + if self.variant == "explicit": + # reshape and permute if BNSD layout + q = q.view(b, s, self.num_heads, self.head_dim) + k = k.view(b, s, self.num_kv_heads, self.head_dim) + if self.layout == "BNSD": + q = q.permute(0, 2, 1, 3).contiguous() + k = k.permute(0, 2, 1, 3).contiguous() + unsq_dim = 1 + else: + unsq_dim = 2 + + cos, sin = _precompute_freqs_cis_explicit(s, self.head_dim, rope_theta=10000) + cos = cos.to(x.device).unsqueeze(0).expand(b, -1, -1) + sin = sin.to(x.device).unsqueeze(0).expand(b, -1, -1) + + if self.mode == "match": + q_out, k_out = apply_rotary_pos_emb_explicit(q, k, cos, sin, unsq_dim) + else: # optimize + q_out, k_out = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin( + q, k, cos, sin, unsq_dim + ) + + # revert layout and flatten + if self.layout == "BNSD": + q_out = q_out.permute(0, 2, 1, 3).reshape(b, s, -1) + k_out = k_out.permute(0, 2, 1, 3).reshape(b, s, -1) + else: + q_out = q_out.reshape(b, s, -1) + k_out = k_out.reshape(b, s, -1) + + else: # complex variant + q = q.view(b, s, self.num_heads, self.head_dim) + k = k.view(b, s, self.num_kv_heads, self.head_dim) + if self.layout == "BNSD": + q = q.permute(0, 2, 1, 3).contiguous() + k = k.permute(0, 2, 1, 3).contiguous() + unsq_dim = 1 + else: + unsq_dim = 2 + + freqs = _precompute_freqs_cis_complex(s, self.head_dim, rope_theta=10000) + freqs = freqs.to(x.device).unsqueeze(0).expand(b, -1, -1) + + if self.mode == "match": + q_out, k_out = apply_rotary_pos_emb_complex(q, k, freqs, 
unsq_dim) + else: + q_out, k_out = torch.ops.rope.torch_apply_rope_with_complex_freqs( + q, k, freqs, unsq_dim + ) + + # revert layout and flatten + if self.layout == "BNSD": + q_out = q_out.permute(0, 2, 1, 3).reshape(b, s, -1) + k_out = k_out.permute(0, 2, 1, 3).reshape(b, s, -1) + else: + q_out = q_out.reshape(b, s, -1) + k_out = k_out.reshape(b, s, -1) + + out = torch.cat([q_out, k_out], dim=-1) + return out.to(torch.float16) if self.mode == "match" else out + + def get_dynamic_shapes(self): + return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)} + + +@pytest.mark.parametrize( + "transformation,variant,layout,batch_size,seq_len,num_heads,num_kv_heads,atol,rtol, target_layout", + [ + ("match", "explicit", "BNSD", 8, 16, 8, 8, 1e-2, 1e-2, None), + ("match", "explicit", "BSND", 8, 16, 8, 4, 1e-2, 1e-2, None), + ("match", "complex", "BNSD", 8, 16, 8, 8, 1e-3, 1e-3, None), + ("match", "complex", "BSND", 8, 16, 8, 4, 1e-3, 1e-3, None), + ("match_layout", "explicit", "BNSD", 4, 12, 8, 8, 1e-3, 1e-3, "BSND"), + ("match_layout", "explicit", "BNSD", 4, 12, 8, 8, 1e-3, 1e-3, "BNSD"), + ("match_layout", "complex", "BNSD", 4, 12, 8, 8, 1e-3, 1e-3, "BSND"), + ("match_layout", "complex", "BSND", 4, 12, 8, 8, 1e-3, 1e-3, "BSND"), + pytest.param( + "optimize", + "explicit", + "BNSD", + 4, + 12, + 8, + 8, + 1e-3, + 1e-3, + None, + marks=pytest.mark.xfail( + reason="flashinfer op does not support BNSD layout", strict=True + ), + ), + ("optimize", "explicit", "BSND", 4, 12, 8, 4, 1e-3, 1e-3, None), + pytest.param( + "optimize", + "complex", + "BNSD", + 4, + 12, + 8, + 8, + 1e-3, + 1e-3, + None, + marks=pytest.mark.xfail( + reason="flashinfer op does not support BNSD layout", strict=True + ), + ), + ("optimize", "complex", "BSND", 4, 12, 8, 4, 1e-3, 1e-3, None), + ], +) +@torch.inference_mode() +def test_rope_variants( + transformation, + variant, + layout, + batch_size, + seq_len, + num_heads, + num_kv_heads, + atol, + rtol, + target_layout, +): + hidden_size = 512 + model = RoPEModel( + hidden_size, + seq_len, + num_heads, + num_kv_heads, + variant=variant, + mode=transformation, + layout=layout or "BNSD", + ).to("cuda", torch.float16) + x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16) + dyn = model.get_dynamic_shapes() + + if transformation == "match": + fn = match_explicit_rope if variant == "explicit" else match_complex_rope + check_op = ( + torch.ops.rope.torch_apply_rope_with_explicit_cos_sin + if variant == "explicit" + else torch.ops.rope.torch_apply_rope_with_complex_freqs + ) + + def checker(gm): + return any(is_op(n, check_op) for n in gm.graph.nodes) + + elif transformation == "match_layout": + fn = match_rope_layout + + def checker(gm): + for n in gm.graph.nodes: + if is_op( + n, + { + torch.ops.rope.torch_apply_rope_with_explicit_cos_sin, + torch.ops.rope.torch_apply_rope_with_complex_freqs, + }, + ): + q_arg, k_arg, *rest = n.args + if not ( + is_op(q_arg, torch.ops.aten.contiguous) + and is_op(k_arg, torch.ops.aten.contiguous) + ): + matched = False + break + + old_q, old_k = extract_output_tuple(n, 2) + if old_q is None or old_k is None: + matched = False + break + q_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_q.users) + k_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_k.users) + matched = q_transposed and k_transposed + + return matched if layout != target_layout else not matched + + else: + fn = optimize_rope + + def checker(gm): + return any(is_op(n, torch.ops.rope.flashinfer) for n in gm.graph.nodes) 
+ + if target_layout: + _ = run_test( + model, + x, + fn, + checker, + lambda n: n, + atol, # atol + rtol, # rtol + True, # test_load_hook + True, # strict_loading + dyn, # dynamic_shapes + target_layout, + ) + else: + _ = run_test( + model, + x, + fn, + checker, + lambda n: n, + atol, # atol + rtol, # rtol + True, # test_load_hook + True, # strict_loading + dyn, # dynamic_shapes + ) + + +class DSRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + t = torch.arange(max_position_embeddings, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) + + def forward(self, x, seq_len=None): + # returns [seq_len, head_dim] cos & sin + return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype) + + +class DSModel(torch.nn.Module): + def __init__(self, hidden_size, max_seq, n_head, n_kv, layout="BNSD", mode: str = "match"): + super().__init__() + self.hdim = hidden_size // n_head + self.layout = layout + self.q_lin = torch.nn.Linear(hidden_size, n_head * self.hdim) + self.k_lin = torch.nn.Linear(hidden_size, n_kv * self.hdim) + self.rotary = DSRotaryEmbedding(self.hdim, max_seq, base=10000, device="cuda") + self.mode = mode # "match" or "optimize" + + def forward(self, x): + b, s, _ = x.shape + q = self.q_lin(x).view(b, s, -1, self.hdim) + k = self.k_lin(x).view(b, s, -1, self.hdim) + if self.layout == "BNSD": + # to [B, N, S, D] + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + unsq_dim = 1 + else: + unsq_dim = 2 + cos, sin = self.rotary(x, seq_len=s) + # build position_ids [B, S] + pos_ids = torch.arange(s, device=x.device).unsqueeze(0).expand(b, s) + if self.mode == "match": + q_out, k_out = apply_rotary_pos_emb_ds(q, k, cos, sin, pos_ids, unsqueeze_dim=unsq_dim) + else: + cos = cos[pos_ids] + sin = sin[pos_ids] + q_out, k_out = torch.ops.rope.torch_apply_rope_with_qk_interleaving( + q, k, cos, sin, unsq_dim + ) + if self.layout == "BNSD": + # back to [B, S, N*D] + q_out = q_out.permute(0, 2, 1, 3).reshape(b, s, -1) + k_out = k_out.permute(0, 2, 1, 3).reshape(b, s, -1) + else: + q_out = q_out.reshape(b, s, -1) + k_out = k_out.reshape(b, s, -1) + return torch.cat([q_out, k_out], dim=-1) + + def get_dynamic_shapes(self): + return {0: Dim("batch_size", max=8), 1: Dim("seq_len", max=16)} + + +@pytest.mark.parametrize( + "layout,num_heads,num_kv_heads,mode, target_layout", + [ + ("BNSD", 8, 8, "match", None), + ("BSND", 8, 4, "match", None), + ("BNSD", 8, 8, "match_layout", "BNSD"), + ("BSND", 8, 4, "match_layout", "BNSD"), + ("BSND", 8, 4, "match_layout", "BSND"), + ("BNSD", 8, 4, "match_layout", "BSND"), + ], +) +@torch.inference_mode() +def test_match_and_layout_deepseek(layout, num_heads, num_kv_heads, mode, target_layout): + batch, seq, hid = 4, 12, 512 + model = DSModel(hid, 16, num_heads, num_kv_heads, layout=layout, mode=mode) + model = model.to("cuda", torch.float16) + + x = torch.randn(batch, seq, hid, device="cuda", dtype=torch.float16) + dynamic_shapes = model.get_dynamic_shapes() + + if mode == "match": + transform = match_explicit_rope + + def checker(gm): + return any( + is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving) + for n 
in gm.graph.nodes + ) + + else: # mode == "match_layout" + transform = match_rope_layout + + def checker(gm): + for n in gm.graph.nodes: + if is_op(n, torch.ops.rope.torch_apply_rope_with_qk_interleaving): + q_arg, k_arg, *rest = n.args + if not ( + is_op(q_arg, torch.ops.aten.contiguous) + and is_op(k_arg, torch.ops.aten.contiguous) + ): + matched = False + break + + old_q, old_k = extract_output_tuple(n, 2) + if old_q is None or old_k is None: + matched = False + break + q_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_q.users) + k_transposed = any(is_op(u, torch.ops.aten.transpose) for u in old_k.users) + matched = q_transposed and k_transposed + + return matched if layout != target_layout else not matched + + if target_layout: + _ = run_test( + model, + x, + transform, + checker, + lambda num_p: num_p, + 1e-3, # atol + 1e-3, # rtol + True, # test_load_hook + True, # strict_loading + dynamic_shapes, # dynamic_shapes + target_layout, + ) + else: + _ = run_test( + model, + x, + transform, + checker, + lambda num_p: num_p, + 1e-3, # atol + 1e-3, # rtol + True, # test_load_hook + True, # strict_loading + dynamic_shapes, # dynamic_shapes + )
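
As a sanity check outside the pytest harness, the reference op exercised above, torch.ops.rope.torch_apply_rope_with_explicit_cos_sin, can be compared against the HF-style rotate_half formula. This is a minimal sketch only: the shapes, tolerances, and the CUDA/float16 setup are illustrative assumptions mirroring the unit tests, and the op is assumed to be registered by importing tensorrt_llm._torch.auto_deploy.

import torch

import tensorrt_llm._torch.auto_deploy  # noqa: F401  # registers torch.ops.rope.*


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def main() -> None:
    device, dtype = "cuda", torch.float16
    b, s, n, d = 2, 4, 3, 64  # illustrative sizes; d matches the head_dim used in the tests

    q = torch.randn(b, s, n, d, dtype=dtype, device=device)
    k = torch.randn(b, s, n, d, dtype=dtype, device=device)

    # Explicit cos/sin tables, duplicated to [s, d] and broadcast to [b, s, d].
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2, dtype=torch.float32, device=device) / d))
    freqs = torch.arange(s, dtype=torch.float32, device=device)[:, None] * inv_freq[None, :]
    emb = torch.cat((freqs, freqs), dim=-1)  # [s, d]
    cos = emb.cos().unsqueeze(0).expand(b, -1, -1)  # [b, s, d]
    sin = emb.sin().unsqueeze(0).expand(b, -1, -1)

    # Reference op in BSND layout (unsqueeze_dim=2), as in RoPEModel above.
    q_op, k_op = torch.ops.rope.torch_apply_rope_with_explicit_cos_sin(q, k, cos, sin, 2)

    # HF-style reference: q * cos + rotate_half(q) * sin, with cos/sin unsqueezed at the head dim.
    q_ref = q * cos.unsqueeze(2) + rotate_half(q) * sin.unsqueeze(2)
    k_ref = k * cos.unsqueeze(2) + rotate_half(k) * sin.unsqueeze(2)

    torch.testing.assert_close(q_op.float(), q_ref.float(), atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(k_op.float(), k_ref.float(), atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    main()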
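
The extract_output_tuple helper added to node_utils.py can likewise be exercised on a toy FX graph. The SplitAdd module and the string-based lookup of the split node below are assumptions made for this example, not part of the library API.

import torch
from torch.fx import symbolic_trace

from tensorrt_llm._torch.auto_deploy.utils.node_utils import extract_output_tuple


class SplitAdd(torch.nn.Module):
    """Toy module whose forward produces a tuple consumed via getitem."""

    def forward(self, x: torch.Tensor):
        parts = torch.split(x, 2, dim=-1)  # tuple-producing call
        return parts[0] + 1.0, parts[1] * 2.0


gm = symbolic_trace(SplitAdd())
# Locate the tuple-producing node; in this toy graph it is the only split call.
split_node = next(
    n for n in gm.graph.nodes if n.op == "call_function" and "split" in str(n.target)
)
# Returns the operator.getitem users for indices 0 and 1 (None where an index has no user).
out0, out1 = extract_output_tuple(split_node, 2)
print(out0, out1)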