[ROCm][GPT-OSS] Fuse RoPE + static Q FP8 quant on fused RoPE+KV path (#42832)

Signed-off-by: Aakif Nawaz <aakif.nawaz@amd.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
akii96
2026-06-06 00:22:19 +03:00
committed by GitHub
parent c73b0d0db9
commit 4200f62147
2 changed files with 451 additions and 1 deletions
@@ -3,11 +3,13 @@
import pytest
import torch
from torch._higher_order_ops import auto_functionalized
import vllm.config
from tests.compile.backend import TestBackend
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
from vllm.compilation.passes.fusion import rope_kvcache_fusion
from vllm.compilation.passes.fusion.matcher_utils import ROTARY_OP
from vllm.compilation.passes.fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
@@ -24,6 +26,7 @@ from vllm.config import (
PassConfig,
VllmConfig,
)
from vllm.config.utils import Range
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
@@ -40,6 +43,20 @@ VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update
FP8_DTYPE = current_platform.fp8_dtype()
def test_rope_kvcache_fusion_default_keeps_large_ranges_unfused():
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(fuse_rope_kvcache=True),
),
)
fusion_pass = RopeKVCacheFusionPass(vllm_config)
assert fusion_pass.is_applicable_for_range(Range(1, 256))
assert not fusion_pass.is_applicable_for_range(Range(257, 11650))
assert not fusion_pass.is_applicable_for_range(Range(11651, 16384))
class QKRoPEKVCacheTestModel(torch.nn.Module):
def __init__(
self,
@@ -184,6 +201,51 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default]
class QKRoPEStaticQKVCacheTestModel(QKRoPEKVCacheTestModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.q_scale = torch.ones((), dtype=torch.float32, device=self.device)
def forward(
self, qkv: torch.Tensor, positions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# Create copy so inplace ops do not modify the original tensors
qkv = qkv.clone()
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
q_fp8 = torch.empty(q.shape, device=q.device, dtype=FP8_DTYPE)
_, q_fp8 = auto_functionalized(
torch.ops._C.static_scaled_fp8_quant.default,
result=q_fp8,
input=q,
scale=self.q_scale,
group_shape=(-1, -1),
)
q = q_fp8.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size)
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
k, v, _encode_layer_name(self.layer_name)
)
return q, k, v, kv_cache_dummy_dep
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
ops = []
if self.enable_rope_custom_op:
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
else:
ops.append(ROTARY_OP)
else:
ops.append(INDEX_SELECT_OP)
ops.append(torch.ops.vllm.unified_kv_cache_update.default)
return ops
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default]
@pytest.mark.parametrize(
"attn_backend",
[
@@ -326,3 +388,198 @@ def test_rope_kvcache_fusion(
atol=ATOL,
rtol=RTOL,
)
@pytest.mark.parametrize(
"attn_backend",
[AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN],
)
@pytest.mark.parametrize("enable_rope_custom_op", [True])
@pytest.mark.parametrize("enable_aiter_triton_rope", [True, False])
@pytest.mark.parametrize("num_heads", [64])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("is_neox", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.skipif(
not is_aiter_found_and_supported(),
reason="Only test on ROCm with AITER installed and supported",
)
@pytest.mark.skipif(
not hasattr(torch.ops._C, "static_scaled_fp8_quant"),
reason="static fp8 quant op not available on this build",
)
def test_rope_static_qquant_kvcache_fusion(
attn_backend: AttentionBackendEnum,
enable_rope_custom_op: bool,
enable_aiter_triton_rope: bool,
num_heads: int,
num_kv_heads: int,
head_size: int,
block_size: int,
is_neox: bool,
dtype: torch.dtype,
kv_cache_dtype: str,
monkeypatch: pytest.MonkeyPatch,
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(0)
custom_ops: list[str] = []
if enable_rope_custom_op:
custom_ops.append("+rotary_embedding")
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
cache_config=CacheConfig(
block_size=block_size,
cache_dtype=kv_cache_dtype,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(
fuse_rope_kvcache=True,
eliminate_noops=True,
),
),
)
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
m.setenv("VLLM_ROCM_USE_AITER", "1")
m.setenv(
"VLLM_ROCM_USE_AITER_TRITON_ROPE", "1" if enable_aiter_triton_rope else "0"
)
rocm_aiter_ops.refresh_env_variables()
model = QKRoPEStaticQKVCacheTestModel(
vllm_config=vllm_config,
attn_backend=attn_backend,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
is_neox=is_neox,
dtype=dtype,
device=torch.get_default_device(),
)
fusion_pass = RopeKVCacheFusionPass(vllm_config)
passes = [
NoOpEliminationPass(vllm_config),
SplitCoalescingPass(vllm_config),
ScatterSplitReplacementPass(vllm_config),
fusion_pass,
PostCleanupPass(vllm_config),
]
backend = TestBackend(*passes)
T = 5
qkv = torch.randn(
T, num_heads * head_size + 2 * num_kv_heads * head_size, dtype=dtype
)
pos = torch.arange(T, dtype=torch.long)
qkv_unfused = qkv.clone()
pos_unfused = pos.clone()
with set_forward_context(None, vllm_config):
forward_context = get_forward_context()
attn_metadata = model.build_attn_metadata(T)
forward_context.slot_mapping = {
model.layer_name: attn_metadata.slot_mapping
}
q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_unfused = attn_layer.kv_cache
del dummy
torch._dynamo.mark_dynamic(qkv, 0)
torch._dynamo.mark_dynamic(pos, 0)
with set_forward_context(None, vllm_config):
model_fused = torch.compile(model, backend=backend)
forward_context = get_forward_context()
attn_metadata = model_fused.build_attn_metadata(T)
forward_context.slot_mapping = {
model.layer_name: attn_metadata.slot_mapping
}
q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_fused = attn_layer.kv_cache
del dummy
assert fusion_pass.matched_count == 1
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())
static_quant_pre = backend.op_count(
torch.ops._C.static_scaled_fp8_quant.default, before=True
)
static_quant_post = backend.op_count(
torch.ops._C.static_scaled_fp8_quant.default
)
assert static_quant_pre > 0
# The replacement still emits static quant, so count is expected to
# remain non-zero after fusion.
assert static_quant_post > 0
# Negative control: without the static-Q pattern, the generic RoPE+KV
# pattern cannot match this rope -> static-quant -> kv graph, so the
# fusion above is attributable solely to RopeStaticQQuantKVCachePattern.
# This is a structural property independent of the rope/dtype/neox axes,
# so run the (extra compile) check only once on a representative combo.
if is_neox and enable_aiter_triton_rope and kv_cache_dtype == "auto":
m.setattr(
rope_kvcache_fusion,
"_supports_static_q_fp8_quant_fusion",
lambda: False,
)
generic_pass = RopeKVCacheFusionPass(vllm_config)
generic_backend = TestBackend(
NoOpEliminationPass(vllm_config),
SplitCoalescingPass(vllm_config),
ScatterSplitReplacementPass(vllm_config),
generic_pass,
PostCleanupPass(vllm_config),
)
# Reset dynamo so the model is recompiled through generic_backend
# instead of reusing the cached compilation from above.
torch._dynamo.reset()
with set_forward_context(None, vllm_config):
model_generic = torch.compile(model, backend=generic_backend)
forward_context = get_forward_context()
attn_metadata = model_generic.build_attn_metadata(T)
forward_context.slot_mapping = {
model.layer_name: attn_metadata.slot_mapping
}
model_generic(qkv, pos)
# op_count reads the post-pass graph, so it also confirms the pass ran
# (a no-op pass would raise instead of silently passing on count 0).
assert generic_pass.matched_count == 0
assert (
generic_backend.op_count(
torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
)
== 0
)
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(
q_unfused.to(torch.float32),
q_fused.to(torch.float32),
atol=ATOL,
rtol=RTOL,
)
torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(
kv_cache_unfused.view(dtype),
kv_cache_fused.view(dtype),
atol=ATOL,
rtol=RTOL,
)