mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm][GPT-OSS] Fuse RoPE + static Q FP8 quant on fused RoPE+KV path (#42832)
Signed-off-by: Aakif Nawaz <aakif.nawaz@amd.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
@@ -3,11 +3,13 @@
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
|
||||
import vllm.config
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
|
||||
from vllm.compilation.passes.fusion import rope_kvcache_fusion
|
||||
from vllm.compilation.passes.fusion.matcher_utils import ROTARY_OP
|
||||
from vllm.compilation.passes.fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
@@ -24,6 +26,7 @@ from vllm.config import (
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.config.utils import Range
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
@@ -40,6 +43,20 @@ VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
def test_rope_kvcache_fusion_default_keeps_large_ranges_unfused():
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(fuse_rope_kvcache=True),
|
||||
),
|
||||
)
|
||||
fusion_pass = RopeKVCacheFusionPass(vllm_config)
|
||||
|
||||
assert fusion_pass.is_applicable_for_range(Range(1, 256))
|
||||
assert not fusion_pass.is_applicable_for_range(Range(257, 11650))
|
||||
assert not fusion_pass.is_applicable_for_range(Range(11651, 16384))
|
||||
|
||||
|
||||
class QKRoPEKVCacheTestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -184,6 +201,51 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
|
||||
return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default]
|
||||
|
||||
|
||||
class QKRoPEStaticQKVCacheTestModel(QKRoPEKVCacheTestModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.q_scale = torch.ones((), dtype=torch.float32, device=self.device)
|
||||
|
||||
def forward(
|
||||
self, qkv: torch.Tensor, positions: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# Create copy so inplace ops do not modify the original tensors
|
||||
qkv = qkv.clone()
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
q_fp8 = torch.empty(q.shape, device=q.device, dtype=FP8_DTYPE)
|
||||
_, q_fp8 = auto_functionalized(
|
||||
torch.ops._C.static_scaled_fp8_quant.default,
|
||||
result=q_fp8,
|
||||
input=q,
|
||||
scale=self.q_scale,
|
||||
group_shape=(-1, -1),
|
||||
)
|
||||
q = q_fp8.view(-1, self.num_heads, self.head_size)
|
||||
k = k.view(-1, self.num_kv_heads, self.head_size)
|
||||
v = v.view(-1, self.num_kv_heads, self.head_size)
|
||||
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
|
||||
k, v, _encode_layer_name(self.layer_name)
|
||||
)
|
||||
return q, k, v, kv_cache_dummy_dep
|
||||
|
||||
def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
|
||||
ops = []
|
||||
if self.enable_rope_custom_op:
|
||||
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
|
||||
ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
|
||||
else:
|
||||
ops.append(ROTARY_OP)
|
||||
else:
|
||||
ops.append(INDEX_SELECT_OP)
|
||||
ops.append(torch.ops.vllm.unified_kv_cache_update.default)
|
||||
return ops
|
||||
|
||||
def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
|
||||
return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"attn_backend",
|
||||
[
|
||||
@@ -326,3 +388,198 @@ def test_rope_kvcache_fusion(
|
||||
atol=ATOL,
|
||||
rtol=RTOL,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"attn_backend",
|
||||
[AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN],
|
||||
)
|
||||
@pytest.mark.parametrize("enable_rope_custom_op", [True])
|
||||
@pytest.mark.parametrize("enable_aiter_triton_rope", [True, False])
|
||||
@pytest.mark.parametrize("num_heads", [64])
|
||||
@pytest.mark.parametrize("num_kv_heads", [8])
|
||||
@pytest.mark.parametrize("head_size", [64])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("is_neox", [True, False])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
|
||||
@pytest.mark.skipif(
|
||||
not is_aiter_found_and_supported(),
|
||||
reason="Only test on ROCm with AITER installed and supported",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not hasattr(torch.ops._C, "static_scaled_fp8_quant"),
|
||||
reason="static fp8 quant op not available on this build",
|
||||
)
|
||||
def test_rope_static_qquant_kvcache_fusion(
|
||||
attn_backend: AttentionBackendEnum,
|
||||
enable_rope_custom_op: bool,
|
||||
enable_aiter_triton_rope: bool,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
is_neox: bool,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
custom_ops: list[str] = []
|
||||
if enable_rope_custom_op:
|
||||
custom_ops.append("+rotary_embedding")
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
cache_config=CacheConfig(
|
||||
block_size=block_size,
|
||||
cache_dtype=kv_cache_dtype,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(
|
||||
fuse_rope_kvcache=True,
|
||||
eliminate_noops=True,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
m.setenv(
|
||||
"VLLM_ROCM_USE_AITER_TRITON_ROPE", "1" if enable_aiter_triton_rope else "0"
|
||||
)
|
||||
rocm_aiter_ops.refresh_env_variables()
|
||||
|
||||
model = QKRoPEStaticQKVCacheTestModel(
|
||||
vllm_config=vllm_config,
|
||||
attn_backend=attn_backend,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
is_neox=is_neox,
|
||||
dtype=dtype,
|
||||
device=torch.get_default_device(),
|
||||
)
|
||||
|
||||
fusion_pass = RopeKVCacheFusionPass(vllm_config)
|
||||
passes = [
|
||||
NoOpEliminationPass(vllm_config),
|
||||
SplitCoalescingPass(vllm_config),
|
||||
ScatterSplitReplacementPass(vllm_config),
|
||||
fusion_pass,
|
||||
PostCleanupPass(vllm_config),
|
||||
]
|
||||
backend = TestBackend(*passes)
|
||||
|
||||
T = 5
|
||||
qkv = torch.randn(
|
||||
T, num_heads * head_size + 2 * num_kv_heads * head_size, dtype=dtype
|
||||
)
|
||||
pos = torch.arange(T, dtype=torch.long)
|
||||
|
||||
qkv_unfused = qkv.clone()
|
||||
pos_unfused = pos.clone()
|
||||
|
||||
with set_forward_context(None, vllm_config):
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = model.build_attn_metadata(T)
|
||||
forward_context.slot_mapping = {
|
||||
model.layer_name: attn_metadata.slot_mapping
|
||||
}
|
||||
q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
|
||||
attn_layer = forward_context.no_compile_layers[model.layer_name]
|
||||
kv_cache_unfused = attn_layer.kv_cache
|
||||
del dummy
|
||||
|
||||
torch._dynamo.mark_dynamic(qkv, 0)
|
||||
torch._dynamo.mark_dynamic(pos, 0)
|
||||
with set_forward_context(None, vllm_config):
|
||||
model_fused = torch.compile(model, backend=backend)
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = model_fused.build_attn_metadata(T)
|
||||
forward_context.slot_mapping = {
|
||||
model.layer_name: attn_metadata.slot_mapping
|
||||
}
|
||||
q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
|
||||
attn_layer = forward_context.no_compile_layers[model.layer_name]
|
||||
kv_cache_fused = attn_layer.kv_cache
|
||||
del dummy
|
||||
|
||||
assert fusion_pass.matched_count == 1
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
static_quant_pre = backend.op_count(
|
||||
torch.ops._C.static_scaled_fp8_quant.default, before=True
|
||||
)
|
||||
static_quant_post = backend.op_count(
|
||||
torch.ops._C.static_scaled_fp8_quant.default
|
||||
)
|
||||
assert static_quant_pre > 0
|
||||
# The replacement still emits static quant, so count is expected to
|
||||
# remain non-zero after fusion.
|
||||
assert static_quant_post > 0
|
||||
|
||||
# Negative control: without the static-Q pattern, the generic RoPE+KV
|
||||
# pattern cannot match this rope -> static-quant -> kv graph, so the
|
||||
# fusion above is attributable solely to RopeStaticQQuantKVCachePattern.
|
||||
# This is a structural property independent of the rope/dtype/neox axes,
|
||||
# so run the (extra compile) check only once on a representative combo.
|
||||
if is_neox and enable_aiter_triton_rope and kv_cache_dtype == "auto":
|
||||
m.setattr(
|
||||
rope_kvcache_fusion,
|
||||
"_supports_static_q_fp8_quant_fusion",
|
||||
lambda: False,
|
||||
)
|
||||
generic_pass = RopeKVCacheFusionPass(vllm_config)
|
||||
generic_backend = TestBackend(
|
||||
NoOpEliminationPass(vllm_config),
|
||||
SplitCoalescingPass(vllm_config),
|
||||
ScatterSplitReplacementPass(vllm_config),
|
||||
generic_pass,
|
||||
PostCleanupPass(vllm_config),
|
||||
)
|
||||
# Reset dynamo so the model is recompiled through generic_backend
|
||||
# instead of reusing the cached compilation from above.
|
||||
torch._dynamo.reset()
|
||||
with set_forward_context(None, vllm_config):
|
||||
model_generic = torch.compile(model, backend=generic_backend)
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = model_generic.build_attn_metadata(T)
|
||||
forward_context.slot_mapping = {
|
||||
model.layer_name: attn_metadata.slot_mapping
|
||||
}
|
||||
model_generic(qkv, pos)
|
||||
# op_count reads the post-pass graph, so it also confirms the pass ran
|
||||
# (a no-op pass would raise instead of silently passing on count 0).
|
||||
assert generic_pass.matched_count == 0
|
||||
assert (
|
||||
generic_backend.op_count(
|
||||
torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
|
||||
)
|
||||
== 0
|
||||
)
|
||||
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(
|
||||
q_unfused.to(torch.float32),
|
||||
q_fused.to(torch.float32),
|
||||
atol=ATOL,
|
||||
rtol=RTOL,
|
||||
)
|
||||
torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL)
|
||||
torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL)
|
||||
torch.testing.assert_close(
|
||||
kv_cache_unfused.view(dtype),
|
||||
kv_cache_fused.view(dtype),
|
||||
atol=ATOL,
|
||||
rtol=RTOL,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user