mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm][DSV4] Enable Tilelang MHC replacing torch/triton mhc (#43679)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -16,3 +16,4 @@ wheel
|
||||
jinja2>=3.1.6
|
||||
amdsmi==7.0.2
|
||||
timm>=1.0.17
|
||||
tilelang==0.1.10
|
||||
|
||||
@@ -22,3 +22,4 @@ timm>=1.0.17
|
||||
# amd-quark: required for Quark quantization on ROCm
|
||||
# To be consistent with test_quark.py
|
||||
amd-quark>=0.8.99
|
||||
tilelang==0.1.10
|
||||
|
||||
@@ -43,6 +43,7 @@ schemathesis>=3.39.15 # Required for openai schema test
|
||||
# quantization
|
||||
bitsandbytes==0.49.2
|
||||
buildkite-test-collector==0.1.9
|
||||
tilelang==0.1.10
|
||||
|
||||
genai_perf>=0.0.8
|
||||
tritonclient>=2.51.0
|
||||
|
||||
@@ -43,7 +43,9 @@ anyio==4.13.0
|
||||
# starlette
|
||||
# watchfiles
|
||||
apache-tvm-ffi==0.1.10
|
||||
# via xgrammar
|
||||
# via
|
||||
# tilelang
|
||||
# xgrammar
|
||||
arctic-inference==0.1.1
|
||||
# via -r requirements/test/rocm.in
|
||||
argcomplete==3.6.3
|
||||
@@ -129,7 +131,9 @@ click==8.3.1
|
||||
# typer
|
||||
# uvicorn
|
||||
cloudpickle==3.1.2
|
||||
# via -r requirements/test/../common.txt
|
||||
# via
|
||||
# -r requirements/test/../common.txt
|
||||
# tilelang
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# perceptron
|
||||
@@ -511,6 +515,8 @@ mistral-common==1.11.2
|
||||
# -c requirements/common.txt
|
||||
# -r requirements/test/../common.txt
|
||||
# -r requirements/test/rocm.in
|
||||
ml-dtypes==0.5.4
|
||||
# via tilelang
|
||||
model-hosting-container-standards==0.1.14
|
||||
# via
|
||||
# -c requirements/common.txt
|
||||
@@ -587,6 +593,7 @@ numpy==2.2.6
|
||||
# lm-eval
|
||||
# matplotlib
|
||||
# mistral-common
|
||||
# ml-dtypes
|
||||
# mteb
|
||||
# numba
|
||||
# opencv-python-headless
|
||||
@@ -610,6 +617,7 @@ numpy==2.2.6
|
||||
# statsmodels
|
||||
# tensorizer
|
||||
# tifffile
|
||||
# tilelang
|
||||
# torchvision
|
||||
# transformers
|
||||
# tritonclient
|
||||
@@ -811,6 +819,7 @@ psutil==7.2.2
|
||||
# accelerate
|
||||
# peft
|
||||
# tensorizer
|
||||
# tilelang
|
||||
py==1.11.0
|
||||
# via pytest-forked
|
||||
py-cpuinfo==9.0.0
|
||||
@@ -1192,6 +1201,10 @@ tiktoken==0.12.0
|
||||
# gpt-oss
|
||||
# lm-eval
|
||||
# mistral-common
|
||||
tilelang==0.1.10
|
||||
# via
|
||||
# -c requirements/rocm.txt
|
||||
# -r requirements/test/rocm.in
|
||||
timm==1.0.17
|
||||
# via
|
||||
# -c requirements/rocm.txt
|
||||
@@ -1208,6 +1221,8 @@ tomli==2.4.0
|
||||
# via schemathesis
|
||||
tomli-w==1.2.0
|
||||
# via schemathesis
|
||||
torch-c-dlpack-ext==0.1.5
|
||||
# via tilelang
|
||||
tqdm==4.67.3
|
||||
# via
|
||||
# -r requirements/test/../common.txt
|
||||
@@ -1225,6 +1240,7 @@ tqdm==4.67.3
|
||||
# pqdm
|
||||
# segmentation-models-pytorch
|
||||
# sentence-transformers
|
||||
# tilelang
|
||||
# transformers
|
||||
transformers==5.5.3
|
||||
# via
|
||||
@@ -1293,6 +1309,7 @@ typing-extensions==4.15.0
|
||||
# sentence-transformers
|
||||
# sqlalchemy
|
||||
# starlette
|
||||
# tilelang
|
||||
# torch
|
||||
# typeguard
|
||||
# typing-inspection
|
||||
@@ -1359,6 +1376,8 @@ yarl==1.23.0
|
||||
# via
|
||||
# aiohttp
|
||||
# schemathesis
|
||||
z3-solver==4.15.4.0
|
||||
# via tilelang
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
|
||||
|
||||
@@ -4,7 +4,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.kernels.mhc # noqa: F401
|
||||
from vllm.model_executor.kernels.mhc.tilelang import (
|
||||
_tilelang_hc_prenorm_gemm,
|
||||
_torch_hc_prenorm_gemm,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_tilelang
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
DEVICE = current_platform.device_type
|
||||
@@ -92,8 +97,128 @@ def hc_head_ref(
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() and has_tilelang()),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_mhc_pre_tilelang(num_tokens, hidden_size, hc_mult):
    """Compare ``torch.ops.vllm.mhc_pre_tilelang`` with ``mhc_pre_ref``.

    Both implementations receive the same randomly generated inputs and
    every returned tensor is compared pairwise.
    """
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
    hc_mult2 = hc_mult * hc_mult
    # Width of the mixing projection: 2 * hc_mult + hc_mult ** 2 rows.
    hc_mult3 = 2 * hc_mult + hc_mult2
    # Small-magnitude weights, scaled slightly differently per residual
    # stream, then flattened to (hc_mult3, hc_mult * hidden_size).
    fn = (
        torch.randn((hc_mult3, hc_mult, hidden_size), dtype=torch.float)
        * 1e-4
        * (1 + torch.arange(hc_mult).mul(0.01).view(1, -1, 1))
    ).flatten(1, 2)
    hc_scale = torch.randn((3,), dtype=torch.float) * 0.1
    hc_base = torch.randn((hc_mult3,), dtype=torch.float) * 0.1

    hc_sinkhorn_eps = hc_pre_eps = rms_eps = 1e-6
    sinkhorn_repeat = 20
    hc_post_alpha = 1.0

    ref = mhc_pre_ref(
        residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_alpha,
        sinkhorn_repeat,
    )
    out = torch.ops.vllm.mhc_pre_tilelang(
        residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_alpha,
        sinkhorn_repeat,
    )

    # strict=True also fails the test if the two implementations return a
    # different number of tensors.
    for actual, expected in zip(out, ref, strict=True):
        torch.testing.assert_close(actual, expected, atol=5e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() and has_tilelang()),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize(
    ("num_tokens", "hidden_size"),
    [
        (1, 1280),
        (512, 1280),
        (2048, 1280),
        (1, 4096),
        (64, 4096),
        (512, 4096),
        (2048, 4096),
        (1, 7168),
        (64, 7168),
        (512, 7168),
        (2048, 7168),
    ],
)
def test_hc_prenorm_gemm_tilelang(num_tokens, hidden_size):
    """Compare the TileLang pre-norm GEMM with the plain torch reference.

    The token counts are spread across small, medium and large batches
    (1 / 64, 512, 2048) for each hidden size.
    """
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    hc_mult = 4
    hc_mult3 = 2 * hc_mult + hc_mult * hc_mult
    x = torch.randn((num_tokens, hc_mult * hidden_size), dtype=torch.bfloat16)
    fn = torch.randn((hc_mult3, hc_mult * hidden_size), dtype=torch.float32) * 1e-4
    # Single-split layout: leading dimension of 1 on both outputs.
    out_ref = torch.empty((1, num_tokens, hc_mult3), dtype=torch.float32)
    sqrsum_ref = torch.empty((1, num_tokens), dtype=torch.float32)
    out = torch.empty_like(out_ref)
    sqrsum = torch.empty_like(sqrsum_ref)

    _torch_hc_prenorm_gemm(x, fn, out_ref, sqrsum_ref)
    _tilelang_hc_prenorm_gemm(x, fn, out, sqrsum, hidden_size, hc_mult)

    torch.testing.assert_close(out, out_ref, atol=1e-5, rtol=1e-4)
    # sqrsum adds hc_mult * hidden_size squared samples per token, so its
    # magnitude is far larger than `out`; the absolute tolerance is loosened
    # accordingly.
    torch.testing.assert_close(sqrsum, sqrsum_ref, atol=8.0, rtol=5e-4)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() and has_tilelang()),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_mhc_post_tilelang(num_tokens, hidden_size, hc_mult):
    """Compare ``torch.ops.vllm.mhc_post_tilelang`` with ``mhc_post_ref``."""
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    x = torch.randn((num_tokens, hidden_size), dtype=torch.bfloat16)
    residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
    post_layer_mix = torch.randn((num_tokens, hc_mult, 1), dtype=torch.float32)
    comb_res_mix = torch.randn((num_tokens, hc_mult, hc_mult), dtype=torch.float32)

    ref = mhc_post_ref(x, residual, post_layer_mix, comb_res_mix)
    out = torch.ops.vllm.mhc_post_tilelang(
        x,
        residual,
        post_layer_mix,
        comb_res_mix,
    )

    torch.testing.assert_close(out, ref, atol=5e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_cuda_alike() and has_tilelang()),
|
||||
reason="CUDA or ROCm and tilelang required",
|
||||
)
|
||||
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
|
||||
@pytest.mark.parametrize("hidden_size", [4096, 7168])
|
||||
@@ -196,3 +321,42 @@ def test_hc_head_triton(num_tokens, hidden_size, hc_mult):
|
||||
|
||||
out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps)
|
||||
torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not (current_platform.is_cuda_alike() and has_tilelang()),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_hc_head_tilelang(num_tokens, hidden_size, hc_mult):
    """Compare ``torch.ops.vllm.hc_head_fused_kernel_tilelang`` with ``hc_head_ref``.

    The op is expected to write its result into ``out`` and return ``None``.
    """
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
    fn = torch.randn((hc_mult, hc_mult * hidden_size), dtype=torch.float32) * 1e-4
    hc_scale = torch.randn((1,), dtype=torch.float32) * 0.1
    hc_base = torch.randn((hc_mult,), dtype=torch.float32) * 0.1
    rms_eps = hc_eps = 1e-6

    # Pre-fill with NaN so any element the kernel fails to write is detected
    # by the isnan check below.
    out = torch.empty((num_tokens, hidden_size), dtype=torch.bfloat16)
    out.fill_(float("nan"))

    result = torch.ops.vllm.hc_head_fused_kernel_tilelang(
        residual,
        fn,
        hc_scale,
        hc_base,
        out,
        hidden_size,
        rms_eps,
        hc_eps,
        hc_mult,
    )

    assert result is None
    assert not torch.isnan(out).any()

    out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps)
    torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2)
|
||||
|
||||
+224
-40
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,8 +10,9 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_tilelang
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
# tilelang is only available on CUDA platforms
|
||||
if TYPE_CHECKING or current_platform.is_cuda():
|
||||
# TileLang is used for MHC on CUDA and ROCm. Keep non-GPU imports cheap so
|
||||
# registering the Python wrapper modules does not require TileLang everywhere.
|
||||
if TYPE_CHECKING or current_platform.is_cuda_alike():
|
||||
if not has_tilelang():
|
||||
raise ImportError(
|
||||
"tilelang is required for mhc but is not installed. Install it with "
|
||||
@@ -23,6 +24,8 @@ else:
|
||||
tilelang = None # type: ignore[assignment]
|
||||
T = None # type: ignore[assignment]
|
||||
|
||||
ENABLE_PDL = current_platform.is_arch_support_pdl() and current_platform.is_cuda()
|
||||
|
||||
|
||||
@cache
|
||||
def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int:
|
||||
@@ -37,12 +40,17 @@ def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int:
|
||||
return split_k
|
||||
|
||||
|
||||
# Compiler pass options shared by the @tilelang.jit kernels in this module,
# so each decorator can pass `pass_configs=pass_configs` instead of repeating
# the dictionary.
pass_configs: dict[tilelang.PassConfigKey, Any] = {
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
}

# The ptxas register-usage level is only added on CUDA; other platforms
# (e.g. ROCm) get the two options above only.
# NOTE(review): presumably because ptxas is specific to the NVIDIA toolchain.
if current_platform.is_cuda():
    pass_configs[tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL] = 10
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
|
||||
},
|
||||
pass_configs=pass_configs,
|
||||
)
|
||||
def mhc_pre_big_fuse_tilelang(
|
||||
gemm_out_mul,
|
||||
@@ -78,7 +86,8 @@ def mhc_pre_big_fuse_tilelang(
|
||||
layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type]
|
||||
|
||||
with T.Kernel(num_tokens, threads=96) as i:
|
||||
T.pdl_sync()
|
||||
if ENABLE_PDL:
|
||||
T.pdl_sync()
|
||||
##################################################################
|
||||
# _pre_norm_fn_fwd_norm
|
||||
rms = T.alloc_fragment(1, T.float32)
|
||||
@@ -174,18 +183,16 @@ def mhc_pre_big_fuse_tilelang(
|
||||
ol[i1_h] += pre * xl[i_hc, i1_h]
|
||||
|
||||
T.copy(ol, layer_input[i, i0_h * hidden_block])
|
||||
T.pdl_trigger()
|
||||
|
||||
if ENABLE_PDL:
|
||||
T.pdl_trigger()
|
||||
|
||||
|
||||
# Copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/mhc.py#L478
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
|
||||
},
|
||||
pass_configs=pass_configs,
|
||||
)
|
||||
def mhc_pre_big_fuse_with_norm_tilelang(
|
||||
gemm_out_mul,
|
||||
@@ -230,7 +237,8 @@ def mhc_pre_big_fuse_with_norm_tilelang(
|
||||
T.clear(mixes)
|
||||
rms[0] = 0
|
||||
|
||||
T.pdl_sync()
|
||||
if ENABLE_PDL:
|
||||
T.pdl_sync()
|
||||
|
||||
for i_split in T.serial(n_splits):
|
||||
rms[0] += gemm_out_sqrsum[i_split, i]
|
||||
@@ -341,15 +349,12 @@ def mhc_pre_big_fuse_with_norm_tilelang(
|
||||
|
||||
T.copy(ol, layer_input[i, i0_h * hidden_block])
|
||||
|
||||
T.pdl_trigger()
|
||||
if ENABLE_PDL:
|
||||
T.pdl_trigger()
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
|
||||
},
|
||||
pass_configs=pass_configs,
|
||||
)
|
||||
def mhc_fused_tilelang(
|
||||
comb_mix,
|
||||
@@ -390,8 +395,8 @@ def mhc_fused_tilelang(
|
||||
|
||||
with T.Kernel(m, n_tiles, split_k, threads=n_thr) as (i_n, i_nt, i_ks):
|
||||
tid = T.get_thread_binding()
|
||||
warp_id = T.get_warp_idx()
|
||||
lane = T.get_lane_idx()
|
||||
warp_id = tid // 32
|
||||
lane = tid % 32
|
||||
|
||||
s_warp = T.alloc_shared((num_warps, tile_n + 1), T.float32)
|
||||
s_post = T.alloc_shared((hc,), T.float32)
|
||||
@@ -407,7 +412,8 @@ def mhc_fused_tilelang(
|
||||
T.clear(sqr)
|
||||
h_split_start = i_ks * h_per_split
|
||||
|
||||
T.pdl_sync()
|
||||
if ENABLE_PDL:
|
||||
T.pdl_sync()
|
||||
|
||||
T.copy(post_mix[i_n, 0], s_post)
|
||||
T.copy(comb_mix[i_n, 0, 0], s_comb)
|
||||
@@ -466,15 +472,12 @@ def mhc_fused_tilelang(
|
||||
v2 += s_warp[w, tile_n]
|
||||
rp_out[i_ks, i_n] = v2
|
||||
|
||||
T.pdl_trigger()
|
||||
if ENABLE_PDL:
|
||||
T.pdl_trigger()
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
pass_configs={
|
||||
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
|
||||
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
|
||||
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
|
||||
},
|
||||
pass_configs=pass_configs,
|
||||
)
|
||||
def mhc_post_tilelang(
|
||||
a,
|
||||
@@ -507,7 +510,8 @@ def mhc_post_tilelang(
|
||||
|
||||
a_local = T.alloc_fragment((hc, hc), T.float32)
|
||||
c_local = T.alloc_fragment(hc, T.float32)
|
||||
T.pdl_sync()
|
||||
if ENABLE_PDL:
|
||||
T.pdl_sync()
|
||||
T.copy(a[i_n, 0, 0], a_local)
|
||||
T.copy(c[i_n, 0], c_local)
|
||||
|
||||
@@ -523,15 +527,193 @@ def mhc_post_tilelang(
|
||||
x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h]
|
||||
|
||||
T.copy(x_local, x[i_n, 0, i0_h * h_blk])
|
||||
T.pdl_trigger()
|
||||
if ENABLE_PDL:
|
||||
T.pdl_trigger()
|
||||
|
||||
|
||||
# Split-K pre-norm GEMM kernel.
#
# For every token n, output column o and K-split s it computes
#     out[s, n, o]  = sum_k x[n, k] * fn[o, k]
#     sqrsum[s, n]  = sum_k x[n, k] ** 2
# where k ranges over the s-th contiguous slice of the
# hc_mult * hidden_size reduction axis.
#
# Launch grid: (num_tokens, ceil(n_out / tile_n), n_splits) with n_thr
# threads per block. Requirements implied by the code below (not checked
# here, callers must guarantee them):
#   * (hc_mult * hidden_size // n_splits) is a multiple of n_thr, otherwise
#     the tail of each split is silently skipped (k_iters uses floor division);
#   * n_thr is a multiple of 32 and tile_n <= 32, because the final
#     reduction addresses threads as 32-wide groups.
@tilelang.jit(
    pass_configs=pass_configs,
)
def hc_prenorm_gemm_tilelang(
    x,
    fn,
    out,
    sqrsum,
    hidden_size: int,
    hc_mult: int = 4,
    n_out: int = 24,
    n_thr: int = 512,
    tile_n: int = 12,
    n_splits: int = 1,
) -> tilelang.JITKernel:
    num_tokens = T.dynamic("num_tokens")
    hc_hidden_size = hc_mult * hidden_size
    k_per_split = hc_hidden_size // n_splits
    # Each thread visits one k index per iteration.
    k_iters = k_per_split // n_thr
    n_tiles = T.ceildiv(n_out, tile_n)

    x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16)  # type: ignore[no-redef, valid-type]
    fn: T.Tensor((n_out, hc_hidden_size), T.float32)  # type: ignore[no-redef, valid-type]
    out: T.Tensor((n_splits, num_tokens, n_out), T.float32)  # type: ignore[no-redef, valid-type]
    sqrsum: T.Tensor((n_splits, num_tokens), T.float32)  # type: ignore[no-redef, valid-type]

    with T.Kernel(num_tokens, n_tiles, n_splits, threads=n_thr) as (
        i_n,
        i_t,
        i_s,
    ):
        tid = T.get_thread_binding()
        # Per-thread float32 partial sums: one per output column of the tile,
        # plus one for the squared-input sum.
        acc = T.alloc_local((tile_n,), T.float32)
        sqr = T.alloc_local((1,), T.float32)
        T.clear(acc)
        T.clear(sqr)

        if ENABLE_PDL:
            T.pdl_sync()

        # Phase 1: strided walk over this split's slice of the K axis.
        for it in T.serial(k_iters):
            i_k = i_s * k_per_split + it * n_thr + tid
            x_val = x[i_n, i_k]
            for i_o in T.unroll(tile_n):
                out_idx = i_t * tile_n + i_o
                # The last tile may extend past n_out.
                if out_idx < n_out:
                    acc[i_o] += x_val * fn[out_idx, i_k]
            # Only the first column tile accumulates the square sum so it is
            # produced once per (token, split).
            if i_t == 0:
                sqr[0] += x_val * x_val

        # Phase 2: reduce the partial sums within each warp.
        for i_o in T.unroll(tile_n):
            acc[i_o] = T.warp_reduce_sum(acc[i_o])
        if i_t == 0:
            sqr[0] = T.warp_reduce_sum(sqr[0])

        # Phase 3: combine the per-warp results through shared memory.
        lane = tid % 32
        warp_id = tid // 32
        num_warps = n_thr // 32
        warp_acc = T.alloc_shared((num_warps, tile_n), T.float32)
        warp_sqr = T.alloc_shared(num_warps, T.float32)

        if lane == 0:
            for i_o in T.unroll(tile_n):
                warp_acc[warp_id, i_o] = acc[i_o]
            if i_t == 0:
                warp_sqr[warp_id] = sqr[0]
        T.sync_threads()

        if warp_id == 0:
            # Lane j of the first warp finalises output column j of the tile.
            if lane < tile_n:
                reduced_acc = T.alloc_var(T.float32, init=0.0)
                for i_w in T.unroll(num_warps):
                    reduced_acc += warp_acc[i_w, lane]
                out_idx = i_t * tile_n + lane
                if out_idx < n_out:
                    out[i_s, i_n, out_idx] = reduced_acc
            if lane == 0 and i_t == 0:
                reduced_sqr = T.alloc_var(T.float32, init=0.0)
                for i_w in T.unroll(num_warps):
                    reduced_sqr += warp_sqr[i_w]
                sqrsum[i_s, i_n] = reduced_sqr

        if ENABLE_PDL:
            T.pdl_trigger()
|
||||
|
||||
|
||||
# Pre-norm GEMM kernel that processes `block_m` tokens per thread block.
#
# Computes, for a single K split (leading output dimension is 1):
#     out[0, n, o]  = sum_k x[n, k] * fn[o, k]
#     sqrsum[0, n]  = sum_k x[n, k] ** 2
# Each block loads a tile of `fn` values once per k index and reuses it for
# all `block_m` tokens it owns.
#
# Launch grid: (ceil(num_tokens / block_m), ceil(n_out / tile_n)) with n_thr
# threads per block. Requirements implied by the code below (not checked
# here, callers must guarantee them):
#   * hc_mult * hidden_size is a multiple of n_thr (k_iters uses floor
#     division);
#   * n_thr is a multiple of 32 and tile_n <= 32, because the final
#     reduction addresses threads as 32-wide groups.
@tilelang.jit(
    pass_configs=pass_configs,
)
def hc_prenorm_gemm_block_m_tilelang(
    x,
    fn,
    out,
    sqrsum,
    hidden_size: int,
    hc_mult: int = 4,
    n_out: int = 24,
    n_thr: int = 512,
    tile_n: int = 12,
    block_m: int = 2,
) -> tilelang.JITKernel:
    num_tokens = T.dynamic("num_tokens")
    hc_hidden_size = hc_mult * hidden_size
    k_iters = hc_hidden_size // n_thr
    n_tiles = T.ceildiv(n_out, tile_n)
    m_tiles = T.ceildiv(num_tokens, block_m)

    x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16)  # type: ignore[no-redef, valid-type]
    fn: T.Tensor((n_out, hc_hidden_size), T.float32)  # type: ignore[no-redef, valid-type]
    out: T.Tensor((1, num_tokens, n_out), T.float32)  # type: ignore[no-redef, valid-type]
    sqrsum: T.Tensor((1, num_tokens), T.float32)  # type: ignore[no-redef, valid-type]

    with T.Kernel(m_tiles, n_tiles, threads=n_thr) as (i_mt, i_t):
        tid = T.get_thread_binding()
        # Per-thread float32 partial sums for every (token, column) pair of
        # the tile and for every token's squared-input sum.
        acc = T.alloc_local((block_m, tile_n), T.float32)
        sqr = T.alloc_local((block_m,), T.float32)
        T.clear(acc)
        T.clear(sqr)

        if ENABLE_PDL:
            T.pdl_sync()

        # Phase 1: strided walk over the whole K axis.
        for it in T.serial(k_iters):
            i_k = it * n_thr + tid
            # Load this k index's weights once; columns past n_out are zeroed
            # so they contribute nothing to the accumulators.
            fn_val = T.alloc_local((tile_n,), T.float32)
            for i_o in T.unroll(tile_n):
                out_idx = i_t * tile_n + i_o
                if out_idx < n_out:
                    fn_val[i_o] = fn[out_idx, i_k]
                else:
                    fn_val[i_o] = 0.0
            for i_m in T.unroll(block_m):
                token_idx = i_mt * block_m + i_m
                # The last token tile may extend past num_tokens.
                if token_idx < num_tokens:
                    x_val = x[token_idx, i_k]
                    for i_o in T.unroll(tile_n):
                        acc[i_m, i_o] += x_val * fn_val[i_o]
                    # Only the first column tile accumulates the square sum.
                    if i_t == 0:
                        sqr[i_m] += x_val * x_val

        # Phase 2: reduce the partial sums within each warp.
        for i_m in T.unroll(block_m):
            for i_o in T.unroll(tile_n):
                acc[i_m, i_o] = T.warp_reduce_sum(acc[i_m, i_o])
            if i_t == 0:
                sqr[i_m] = T.warp_reduce_sum(sqr[i_m])

        # Phase 3: combine the per-warp results through shared memory.
        lane = tid % 32
        warp_id = tid // 32
        num_warps = n_thr // 32
        warp_acc = T.alloc_shared((num_warps, block_m, tile_n), T.float32)
        warp_sqr = T.alloc_shared((num_warps, block_m), T.float32)

        if lane == 0:
            for i_m in T.unroll(block_m):
                for i_o in T.unroll(tile_n):
                    warp_acc[warp_id, i_m, i_o] = acc[i_m, i_o]
                if i_t == 0:
                    warp_sqr[warp_id, i_m] = sqr[i_m]
        T.sync_threads()

        if warp_id == 0:
            for i_m in T.unroll(block_m):
                token_idx = i_mt * block_m + i_m
                if token_idx < num_tokens:
                    # Lane j of the first warp finalises output column j.
                    if lane < tile_n:
                        reduced_acc = T.alloc_var(T.float32, init=0.0)
                        for i_w in T.unroll(num_warps):
                            reduced_acc += warp_acc[i_w, i_m, lane]
                        out_idx = i_t * tile_n + lane
                        if out_idx < n_out:
                            out[0, token_idx, out_idx] = reduced_acc
                    if lane == 0 and i_t == 0:
                        reduced_sqr = T.alloc_var(T.float32, init=0.0)
                        for i_w in T.unroll(num_warps):
                            reduced_sqr += warp_sqr[i_w, i_m]
                        sqrsum[0, token_idx] = reduced_sqr

        if ENABLE_PDL:
            T.pdl_trigger()
|
||||
|
||||
|
||||
@tilelang.jit(
|
||||
pass_configs=pass_configs,
|
||||
)
|
||||
def hc_head_fuse_tilelang(
|
||||
residual,
|
||||
@@ -566,7 +748,8 @@ def hc_head_fuse_tilelang(
|
||||
out: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef,valid-type]
|
||||
|
||||
with T.Kernel(num_tokens, threads=n_thr) as i:
|
||||
T.pdl_sync()
|
||||
if ENABLE_PDL:
|
||||
T.pdl_sync()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pass 1 – for each residual channel m_c and h_block:
|
||||
@@ -624,4 +807,5 @@ def hc_head_fuse_tilelang(
|
||||
|
||||
T.copy(ol, out[i, i0_h * h_block], disable_tma=True)
|
||||
|
||||
T.pdl_trigger()
|
||||
if ENABLE_PDL:
|
||||
T.pdl_trigger()
|
||||
|
||||
@@ -5,6 +5,88 @@ import torch
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
def _torch_hc_prenorm_gemm(
    x: torch.Tensor,
    fn: torch.Tensor,
    out: torch.Tensor,
    sqrsum: torch.Tensor,
) -> None:
    """Reference pre-norm GEMM implemented with plain torch ops.

    Writes ``x.float() @ fn.T`` into ``out[0]`` and the per-row sum of
    squares of ``x.float()`` into ``sqrsum[0]``. Both outputs are modified
    in place; nothing is returned.

    Args:
        x: Input matrix; upcast to float32 before any arithmetic.
        fn: Weight matrix whose rows are the output columns.
        out: Destination with a leading (split) dimension of exactly 1.
        sqrsum: Destination with a leading (split) dimension of exactly 1.

    Raises:
        AssertionError: If either destination has a leading dimension != 1.
    """
    # This reference has no split-K support: only the single-split layout.
    assert out.shape[0] == 1, "reference implementation supports one split only"
    assert sqrsum.shape[0] == 1, "reference implementation supports one split only"
    x_float = x.float()
    out[0].copy_(x_float @ fn.t())
    sqrsum[0].copy_(x_float.square().sum(dim=-1))
|
||||
|
||||
|
||||
def _tilelang_hc_prenorm_gemm(
    x: torch.Tensor,
    fn: torch.Tensor,
    out: torch.Tensor,
    sqrsum: torch.Tensor,
    hidden_size: int,
    hc_mult: int,
    tile_n: int = 12,
    n_thr: int = 512,
    n_splits: int = 1,
) -> None:
    """Run the TileLang pre-norm GEMM, choosing a kernel by batch size.

    Results are written into ``out`` and ``sqrsum``; nothing is returned.

    Kernel selection (only when ``n_splits == 1`` and ``tile_n`` / ``n_thr``
    are left at their defaults of 12 / 512):

    * ``x.shape[0] >= 1024``: the block-M kernel with ``block_m=2``.
    * ``x.shape[0] < 128`` and ``x.shape[1]`` divisible by 1024: the split
      kernel with 1024 threads and ``tile_n=4``.

    Every other case uses the split kernel with the caller's parameters.

    Args:
        x: Input whose second dimension equals ``hc_mult * hidden_size``.
        fn: Weights; ``fn.shape[0]`` is forwarded as the output width.
        out: Destination whose leading dimension equals ``n_splits``.
        sqrsum: Destination whose leading dimension equals ``n_splits``.
        hidden_size: Hidden size of one residual stream.
        hc_mult: Number of residual streams.
        tile_n: Output columns handled per thread block (at most 32).
        n_thr: Threads per block (multiple of 32).
        n_splits: Number of slices the reduction axis is split into.

    Raises:
        AssertionError: If any shape or launch-parameter constraint fails.
    """
    # Validate before importing the kernels so misuse is reported even on
    # machines where the TileLang module cannot be loaded.
    assert out.shape[0] == n_splits
    assert sqrsum.shape[0] == n_splits
    assert x.shape[1] == hc_mult * hidden_size
    assert x.shape[1] % n_splits == 0
    assert (x.shape[1] // n_splits) % n_thr == 0
    # Both kernels finish the reduction with `lane = tid % 32`,
    # `num_warps = n_thr // 32` and `if lane < tile_n`: a partial last warp
    # would index past the per-warp scratch, and columns beyond lane 31
    # would never be written.
    assert n_thr % 32 == 0, "n_thr must be a multiple of 32"
    assert 0 < tile_n <= 32, "tile_n must be in [1, 32]"

    from vllm._tilelang_ops import (
        hc_prenorm_gemm_block_m_tilelang,
        hc_prenorm_gemm_tilelang,
    )

    use_default_config = tile_n == 12 and n_thr == 512
    if n_splits == 1 and use_default_config and x.shape[0] >= 1024:
        # Large batch: share each loaded weight tile across two tokens.
        hc_prenorm_gemm_block_m_tilelang(
            x,
            fn,
            out,
            sqrsum,
            hidden_size,
            hc_mult,
            fn.shape[0],
            n_thr,
            tile_n,
            2,
        )
        return
    if (
        n_splits == 1
        and use_default_config
        and x.shape[0] < 128
        and x.shape[1] % 1024 == 0
    ):
        # Small batch: 1024 threads with narrower column tiles (tile_n=4).
        # The `% 1024` guard keeps the K axis divisible by the larger thread
        # count, which the assertion above only checked for n_thr=512.
        hc_prenorm_gemm_tilelang(
            x,
            fn,
            out,
            sqrsum,
            hidden_size,
            hc_mult,
            fn.shape[0],
            1024,
            4,
            n_splits,
        )
        return
    hc_prenorm_gemm_tilelang(
        x,
        fn,
        out,
        sqrsum,
        hidden_size,
        hc_mult,
        fn.shape[0],
        n_thr,
        tile_n,
        n_splits,
    )
|
||||
|
||||
|
||||
def mhc_pre_tilelang(
|
||||
residual: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
@@ -80,10 +162,16 @@ def mhc_pre_tilelang(
|
||||
residual_flat = residual.view(-1, hc_mult, hidden_size)
|
||||
num_tokens = residual_flat.shape[0]
|
||||
|
||||
# these numbers are from deepgemm kernel impl
|
||||
block_k = 64
|
||||
block_m = 64
|
||||
n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m))
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
|
||||
use_deep_gemm = is_deep_gemm_supported()
|
||||
if use_deep_gemm:
|
||||
# these numbers are from deepgemm kernel impl
|
||||
block_k = 64
|
||||
block_m = 64
|
||||
n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m))
|
||||
else:
|
||||
n_splits = 1
|
||||
|
||||
post_mix = torch.empty(
|
||||
num_tokens, hc_mult, dtype=torch.float32, device=residual.device
|
||||
@@ -102,13 +190,24 @@ def mhc_pre_tilelang(
|
||||
n_splits, num_tokens, dtype=torch.float32, device=residual.device
|
||||
)
|
||||
|
||||
tf32_hc_prenorm_gemm(
|
||||
residual_flat.view(num_tokens, hc_mult * hidden_size),
|
||||
fn,
|
||||
gemm_out_mul,
|
||||
gemm_out_sqrsum,
|
||||
n_splits,
|
||||
)
|
||||
residual_2d = residual_flat.view(num_tokens, hc_mult * hidden_size)
|
||||
if use_deep_gemm:
|
||||
tf32_hc_prenorm_gemm(
|
||||
residual_2d,
|
||||
fn,
|
||||
gemm_out_mul,
|
||||
gemm_out_sqrsum,
|
||||
n_splits,
|
||||
)
|
||||
else:
|
||||
_tilelang_hc_prenorm_gemm(
|
||||
residual_2d,
|
||||
fn,
|
||||
gemm_out_mul,
|
||||
gemm_out_sqrsum,
|
||||
hidden_size,
|
||||
hc_mult,
|
||||
)
|
||||
|
||||
if norm_weight is None:
|
||||
mhc_pre_big_fuse_tilelang(
|
||||
@@ -304,16 +403,24 @@ def mhc_fused_post_pre_tilelang(
|
||||
post_layer_mix_flat = post_layer_mix.view(num_tokens, hc_mult)
|
||||
comb_res_mix_flat = comb_res_mix.view(num_tokens, hc_mult, hc_mult)
|
||||
|
||||
fma_token_threshold = 16
|
||||
if num_tokens <= fma_token_threshold:
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
|
||||
use_deep_gemm = is_deep_gemm_supported()
|
||||
use_small_fma = num_tokens <= 16
|
||||
if use_small_fma:
|
||||
# TODO(gnovack): investigate autotuning these heuristics
|
||||
tile_n = 2 if num_tokens < 8 else 3
|
||||
n_splits = 8 if (num_tokens < 8 and hidden_size <= 4096) else 4
|
||||
else:
|
||||
# these number are from deepgemm kernel impl
|
||||
block_k = 64
|
||||
block_m = 64
|
||||
n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m))
|
||||
if use_deep_gemm:
|
||||
# these number are from deepgemm kernel impl
|
||||
block_k = 64
|
||||
block_m = 64
|
||||
n_splits = compute_num_split(
|
||||
block_k, hc_hidden_size, cdiv(num_tokens, block_m)
|
||||
)
|
||||
else:
|
||||
n_splits = 1
|
||||
|
||||
gemm_out_mul = torch.empty(
|
||||
n_splits,
|
||||
@@ -348,7 +455,7 @@ def mhc_fused_post_pre_tilelang(
|
||||
device=residual.device,
|
||||
)
|
||||
|
||||
if num_tokens <= fma_token_threshold:
|
||||
if use_small_fma:
|
||||
mhc_fused_tilelang(
|
||||
comb_res_mix_flat,
|
||||
residual_flat,
|
||||
@@ -375,15 +482,26 @@ def mhc_fused_post_pre_tilelang(
|
||||
residual.shape[-1],
|
||||
)
|
||||
|
||||
from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm
|
||||
residual_cur_2d = residual_cur.view(num_tokens, hc_mult * hidden_size)
|
||||
if use_deep_gemm:
|
||||
from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm
|
||||
|
||||
tf32_hc_prenorm_gemm(
|
||||
residual_cur.view(num_tokens, hc_mult * hidden_size),
|
||||
fn,
|
||||
gemm_out_mul,
|
||||
gemm_out_sqrsum,
|
||||
n_splits,
|
||||
)
|
||||
tf32_hc_prenorm_gemm(
|
||||
residual_cur_2d,
|
||||
fn,
|
||||
gemm_out_mul,
|
||||
gemm_out_sqrsum,
|
||||
n_splits,
|
||||
)
|
||||
else:
|
||||
_tilelang_hc_prenorm_gemm(
|
||||
residual_cur_2d,
|
||||
fn,
|
||||
gemm_out_mul,
|
||||
gemm_out_sqrsum,
|
||||
hidden_size,
|
||||
hc_mult,
|
||||
)
|
||||
|
||||
if norm_weight is None:
|
||||
mhc_pre_big_fuse_tilelang(
|
||||
|
||||
@@ -3,8 +3,12 @@
|
||||
import torch
|
||||
|
||||
# this import will also register the custom ops
|
||||
# import vllm.model_executor.kernels.mhc # noqa: F401
|
||||
import vllm.model_executor.kernels.mhc as mhc_kernels
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.utils.import_utils import has_tilelang
|
||||
|
||||
HAS_TILELANG = has_tilelang()
|
||||
|
||||
|
||||
# --8<-- [start:mhc_pre]
|
||||
@@ -85,6 +89,52 @@ class MHCPreOp(CustomOp):
|
||||
# sinkhorn_repeat,
|
||||
# )
|
||||
# else:
|
||||
if HAS_TILELANG:
|
||||
return torch.ops.vllm.mhc_pre_tilelang(
|
||||
residual,
|
||||
fn,
|
||||
hc_scale,
|
||||
hc_base,
|
||||
rms_eps,
|
||||
hc_pre_eps,
|
||||
hc_sinkhorn_eps,
|
||||
hc_post_mult_value,
|
||||
sinkhorn_repeat,
|
||||
n_splits,
|
||||
norm_weight,
|
||||
norm_eps,
|
||||
)
|
||||
else:
|
||||
return self.forward_native(
|
||||
residual,
|
||||
fn,
|
||||
hc_scale,
|
||||
hc_base,
|
||||
rms_eps,
|
||||
hc_pre_eps,
|
||||
hc_sinkhorn_eps,
|
||||
hc_post_mult_value,
|
||||
sinkhorn_repeat,
|
||||
n_splits,
|
||||
norm_weight,
|
||||
norm_eps,
|
||||
)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
residual: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
rms_eps: float,
|
||||
hc_pre_eps: float,
|
||||
hc_sinkhorn_eps: float,
|
||||
hc_post_mult_value: float,
|
||||
sinkhorn_repeat: int,
|
||||
n_splits: int = 1,
|
||||
norm_weight: torch.Tensor | None = None,
|
||||
norm_eps: float = 0.0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
return mhc_kernels.mhc_pre_torch(
|
||||
residual,
|
||||
fn,
|
||||
@@ -97,9 +147,6 @@ class MHCPreOp(CustomOp):
|
||||
sinkhorn_repeat,
|
||||
)
|
||||
|
||||
def forward_native(self, *args, **kwargs):
|
||||
raise NotImplementedError("Native implementation of mhc_pre is not available")
|
||||
|
||||
|
||||
# --8<-- [start:mhc_post]
|
||||
@CustomOp.register("mhc_post")
|
||||
@@ -147,6 +194,20 @@ class MHCPostOp(CustomOp):
|
||||
# comb_res_mix,
|
||||
# )
|
||||
# else:
|
||||
if HAS_TILELANG:
|
||||
return torch.ops.vllm.mhc_post_tilelang(
|
||||
x, residual, post_layer_mix, comb_res_mix
|
||||
)
|
||||
else:
|
||||
return self.forward_native(x, residual, post_layer_mix, comb_res_mix)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
post_layer_mix: torch.Tensor,
|
||||
comb_res_mix: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return mhc_kernels.mhc_post_torch(
|
||||
x,
|
||||
residual,
|
||||
@@ -154,9 +215,6 @@ class MHCPostOp(CustomOp):
|
||||
comb_res_mix,
|
||||
)
|
||||
|
||||
def forward_native(self, *args, **kwargs):
|
||||
raise NotImplementedError("Native implementation of mhc_post is not available")
|
||||
|
||||
|
||||
# --8<-- [start:hc_head]
|
||||
@CustomOp.register("hc_head")
|
||||
@@ -220,17 +278,32 @@ class HCHeadOp(CustomOp):
|
||||
out = torch.empty(
|
||||
num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
|
||||
)
|
||||
torch.ops.vllm.hc_head_triton(
|
||||
hs_flat,
|
||||
hc_fn,
|
||||
hc_scale,
|
||||
hc_base,
|
||||
out,
|
||||
hidden_size,
|
||||
rms_norm_eps,
|
||||
hc_eps,
|
||||
hc_mult,
|
||||
)
|
||||
|
||||
if HAS_TILELANG:
|
||||
torch.ops.vllm.hc_head_fused_kernel_tilelang(
|
||||
hs_flat,
|
||||
hc_fn,
|
||||
hc_scale,
|
||||
hc_base,
|
||||
out,
|
||||
hidden_size,
|
||||
rms_norm_eps,
|
||||
hc_eps,
|
||||
hc_mult,
|
||||
)
|
||||
else:
|
||||
torch.ops.vllm.hc_head_triton(
|
||||
hs_flat,
|
||||
hc_fn,
|
||||
hc_scale,
|
||||
hc_base,
|
||||
out,
|
||||
hidden_size,
|
||||
rms_norm_eps,
|
||||
hc_eps,
|
||||
hc_mult,
|
||||
)
|
||||
|
||||
return out.view(*outer_shape, hidden_size)
|
||||
|
||||
def forward_native(self, *args, **kwargs):
|
||||
@@ -290,9 +363,42 @@ class MHCFusedPostPreOp(CustomOp):
|
||||
norm_eps,
|
||||
)
|
||||
|
||||
def forward_hip(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Hip implementation of mhc_fused_post_pre is not available"
|
||||
def forward_hip(
    self,
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
    tile_n: int = 1,
    norm_weight: torch.Tensor | None = None,
    norm_eps: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """HIP variant of the fused post/pre op, backed by the TileLang kernel.

    Every argument is forwarded positionally, unchanged and in declaration
    order, to ``torch.ops.vllm.mhc_fused_post_pre_tilelang``; whatever that
    op returns is returned as-is (annotated as a 4-tuple of tensors).
    """
    tilelang_op = torch.ops.vllm.mhc_fused_post_pre_tilelang
    # Keep this order in sync with the signature above: the op is called
    # positionally, so a reordering here silently swaps inputs.
    positional = (
        x,
        residual,
        post_layer_mix,
        comb_res_mix,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        n_splits,
        tile_n,
        norm_weight,
        norm_eps,
    )
    return tilelang_op(*positional)
|
||||
|
||||
def forward_native(self, *args, **kwargs):
|
||||
|
||||
@@ -54,6 +54,7 @@ from vllm.models.deepseek_v4.attention import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import has_tilelang
|
||||
|
||||
|
||||
class DeepseekV4MLP(nn.Module):
|
||||
@@ -473,6 +474,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
self.mhc_pre = MHCPreOp()
|
||||
self.mhc_post = MHCPostOp()
|
||||
self.mhc_fused_post_pre = MHCFusedPostPreOp()
|
||||
self.has_tilelang = has_tilelang()
|
||||
|
||||
def hc_pre(
|
||||
self,
|
||||
@@ -503,7 +505,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
):
|
||||
return self.mhc_post(x, residual, post, comb)
|
||||
|
||||
def _forward_cuda(
|
||||
def _forward_fused_post_pre(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -555,7 +557,7 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
x = self.ffn(x, input_ids)
|
||||
return x, residual, post_mix, res_mix
|
||||
|
||||
def _forward_rocm(
|
||||
def _forward_unfused_post_pre(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -594,12 +596,13 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
) -> tuple[
|
||||
torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
|
||||
]:
|
||||
if current_platform.is_rocm():
|
||||
return self._forward_rocm(
|
||||
if not self.has_tilelang:
|
||||
return self._forward_unfused_post_pre(
|
||||
x, positions, input_ids, post_mix, res_mix, residual
|
||||
)
|
||||
|
||||
return self._forward_cuda(x, positions, input_ids, post_mix, res_mix, residual)
|
||||
return self._forward_fused_post_pre(
|
||||
x, positions, input_ids, post_mix, res_mix, residual
|
||||
)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
@@ -682,6 +685,7 @@ class DeepseekV4Model(nn.Module):
|
||||
requires_grad=False,
|
||||
)
|
||||
self.hc_head_op = HCHeadOp()
|
||||
self.has_tilelang = has_tilelang()
|
||||
# Pre-hc_head residual stream buffer for the MTP draft. Stable
|
||||
# address (outside the cudagraph pool) so the copy_ in forward()
|
||||
# refreshes it correctly across captured shapes.
|
||||
@@ -748,7 +752,7 @@ class DeepseekV4Model(nn.Module):
|
||||
res_mix,
|
||||
residual,
|
||||
)
|
||||
if layer is not None and current_platform.is_cuda():
|
||||
if layer is not None and self.has_tilelang:
|
||||
hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
|
||||
@@ -39,6 +39,7 @@ from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weigh
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import has_tilelang
|
||||
|
||||
from .model import DeepseekV4DecoderLayer
|
||||
|
||||
@@ -118,6 +119,7 @@ class DeepSeekV4MultiTokenPredictorLayer(nn.Module):
|
||||
)
|
||||
|
||||
self.hc_head_op = HCHeadOp()
|
||||
self.has_tilelang = has_tilelang()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -144,7 +146,7 @@ class DeepSeekV4MultiTokenPredictorLayer(nn.Module):
|
||||
hidden_states, residual, post_mix, res_mix = self.mtp_block(
|
||||
positions=positions, x=hidden_states, input_ids=None
|
||||
)
|
||||
if current_platform.is_cuda():
|
||||
if self.has_tilelang:
|
||||
hidden_states = self.mtp_block.hc_post(
|
||||
hidden_states, residual, post_mix, res_mix
|
||||
)
|
||||
|
||||
@@ -592,6 +592,15 @@ class CudaPlatformBase(Platform):
|
||||
default, rms_norm=rms_norm, fused_add_rms_norm=rms_norm
|
||||
)
|
||||
|
||||
@classmethod
def is_arch_support_pdl(cls) -> bool:
    """Report whether the current CUDA device's architecture supports PDL.

    Returns:
        True when the current device's compute-capability major version is
        9 or higher; False otherwise, including when the device or its
        capability cannot be queried.
    """
    try:
        cc_major, _cc_minor = torch.cuda.get_device_capability(
            torch.cuda.current_device()
        )
    except Exception:
        # Best-effort probe: any failure to query the device is treated as
        # "not supported" rather than propagated to the caller.
        return False
    return cc_major >= 9
|
||||
|
||||
|
||||
# NVML utils
|
||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||
|
||||
@@ -1016,6 +1016,13 @@ class Platform:
|
||||
# Native always used by default. Platforms can override this behavior.
|
||||
return IrOpPriorityConfig.with_default(["native"])
|
||||
|
||||
@classmethod
def is_arch_support_pdl(cls) -> bool:
    """Tell whether the current platform supports PDL.

    PDL stands for Programmatic Dependent Launch. This base implementation
    always answers ``False``.
    """
    pdl_supported = False
    return pdl_supported
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
||||
@@ -430,6 +430,7 @@ def has_triton_kernels() -> bool:
|
||||
return is_available
|
||||
|
||||
|
||||
@cache
def has_tilelang() -> bool:
    """Report whether the optional `tilelang` package is available.

    The check is delegated to ``_has_module("tilelang")``; the ``cache``
    decorator memoizes the answer for this zero-argument function.
    """
    tilelang_available = _has_module("tilelang")
    return tilelang_available
|
||||
|
||||
Reference in New Issue
Block a user