[ROCm][DSV4] Enable Tilelang MHC replacing torch/triton mhc (#43679)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian
2026-05-28 15:05:28 +08:00
committed by GitHub
parent e1814f822d
commit 0ba46d4b11
13 changed files with 715 additions and 98 deletions
+1
View File
@@ -16,3 +16,4 @@ wheel
jinja2>=3.1.6
amdsmi==7.0.2
timm>=1.0.17
tilelang==0.1.10
+1
View File
@@ -22,3 +22,4 @@ timm>=1.0.17
# amd-quark: required for Quark quantization on ROCm
# To be consistent with test_quark.py
amd-quark>=0.8.99
tilelang==0.1.10
+1
View File
@@ -43,6 +43,7 @@ schemathesis>=3.39.15 # Required for openai schema test
# quantization
bitsandbytes==0.49.2
buildkite-test-collector==0.1.9
tilelang==0.1.10
genai_perf>=0.0.8
tritonclient>=2.51.0
+21 -2
View File
@@ -43,7 +43,9 @@ anyio==4.13.0
# starlette
# watchfiles
apache-tvm-ffi==0.1.10
# via xgrammar
# via
# tilelang
# xgrammar
arctic-inference==0.1.1
# via -r requirements/test/rocm.in
argcomplete==3.6.3
@@ -129,7 +131,9 @@ click==8.3.1
# typer
# uvicorn
cloudpickle==3.1.2
# via -r requirements/test/../common.txt
# via
# -r requirements/test/../common.txt
# tilelang
colorama==0.4.6
# via
# perceptron
@@ -511,6 +515,8 @@ mistral-common==1.11.2
# -c requirements/common.txt
# -r requirements/test/../common.txt
# -r requirements/test/rocm.in
ml-dtypes==0.5.4
# via tilelang
model-hosting-container-standards==0.1.14
# via
# -c requirements/common.txt
@@ -587,6 +593,7 @@ numpy==2.2.6
# lm-eval
# matplotlib
# mistral-common
# ml-dtypes
# mteb
# numba
# opencv-python-headless
@@ -610,6 +617,7 @@ numpy==2.2.6
# statsmodels
# tensorizer
# tifffile
# tilelang
# torchvision
# transformers
# tritonclient
@@ -811,6 +819,7 @@ psutil==7.2.2
# accelerate
# peft
# tensorizer
# tilelang
py==1.11.0
# via pytest-forked
py-cpuinfo==9.0.0
@@ -1192,6 +1201,10 @@ tiktoken==0.12.0
# gpt-oss
# lm-eval
# mistral-common
tilelang==0.1.10
# via
# -c requirements/rocm.txt
# -r requirements/test/rocm.in
timm==1.0.17
# via
# -c requirements/rocm.txt
@@ -1208,6 +1221,8 @@ tomli==2.4.0
# via schemathesis
tomli-w==1.2.0
# via schemathesis
torch-c-dlpack-ext==0.1.5
# via tilelang
tqdm==4.67.3
# via
# -r requirements/test/../common.txt
@@ -1225,6 +1240,7 @@ tqdm==4.67.3
# pqdm
# segmentation-models-pytorch
# sentence-transformers
# tilelang
# transformers
transformers==5.5.3
# via
@@ -1293,6 +1309,7 @@ typing-extensions==4.15.0
# sentence-transformers
# sqlalchemy
# starlette
# tilelang
# torch
# typeguard
# typing-inspection
@@ -1359,6 +1376,8 @@ yarl==1.23.0
# via
# aiohttp
# schemathesis
z3-solver==4.15.4.0
# via tilelang
zipp==3.23.0
# via importlib-metadata
+166 -2
View File
@@ -4,7 +4,12 @@ import pytest
import torch
import vllm.model_executor.kernels.mhc # noqa: F401
from vllm.model_executor.kernels.mhc.tilelang import (
_tilelang_hc_prenorm_gemm,
_torch_hc_prenorm_gemm,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_tilelang
from vllm.utils.torch_utils import set_random_seed
DEVICE = current_platform.device_type
@@ -92,8 +97,128 @@ def hc_head_ref(
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="CUDA required",
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_mhc_pre_tilelang(num_tokens, hidden_size, hc_mult):
    """Compare every output of the `mhc_pre_tilelang` op with `mhc_pre_ref`."""
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    # NOTE: the random draws below must stay in this order so the seeded
    # inputs are reproducible.
    residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)

    # Mixing rows: 2 * hc_mult + hc_mult ** 2.
    n_mix = 2 * hc_mult + hc_mult * hc_mult
    raw_fn = torch.randn((n_mix, hc_mult, hidden_size), dtype=torch.float) * 1e-4
    channel_ramp = 1 + torch.arange(hc_mult).mul(0.01).view(1, -1, 1)
    fn = (raw_fn * channel_ramp).flatten(1, 2)

    hc_scale = torch.randn((3,), dtype=torch.float) * 0.1
    hc_base = torch.randn((n_mix,), dtype=torch.float) * 0.1

    eps = 1e-6
    # Positional order: residual, fn, hc_scale, hc_base, rms_eps, hc_pre_eps,
    # hc_sinkhorn_eps, hc_post_alpha, sinkhorn_repeat.
    op_args = (residual, fn, hc_scale, hc_base, eps, eps, eps, 1.0, 20)

    expected_outputs = mhc_pre_ref(*op_args)
    actual_outputs = torch.ops.vllm.mhc_pre_tilelang(*op_args)

    for got, want in zip(actual_outputs, expected_outputs, strict=True):
        torch.testing.assert_close(got, want, atol=5e-2, rtol=1e-2)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike() or not has_tilelang(),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize(
    ("num_tokens", "hidden_size"),
    [
        (1, 1280),
        (512, 1280),
        (2048, 1280),
        (1, 4096),
        (64, 4096),
        (512, 4096),
        (2048, 4096),
        (1, 7168),
        (64, 7168),
        (512, 7168),
        (2048, 7168),
    ],
)
def test_hc_prenorm_gemm_tilelang(num_tokens, hidden_size):
    """Check the TileLang prenorm GEMM wrapper against the torch reference."""
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    hc_mult = 4
    n_out = 2 * hc_mult + hc_mult * hc_mult
    flat_hidden = hc_mult * hidden_size

    x = torch.randn((num_tokens, flat_hidden), dtype=torch.bfloat16)
    fn = torch.randn((n_out, flat_hidden), dtype=torch.float32) * 1e-4

    # Both implementations write into caller-provided buffers whose leading
    # dimension is the split count (1 here).
    expected_out = torch.empty((1, num_tokens, n_out), dtype=torch.float32)
    expected_sqrsum = torch.empty((1, num_tokens), dtype=torch.float32)
    actual_out = torch.empty_like(expected_out)
    actual_sqrsum = torch.empty_like(expected_sqrsum)

    _torch_hc_prenorm_gemm(x, fn, expected_out, expected_sqrsum)
    _tilelang_hc_prenorm_gemm(x, fn, actual_out, actual_sqrsum, hidden_size, hc_mult)

    torch.testing.assert_close(actual_out, expected_out, atol=1e-5, rtol=1e-4)
    torch.testing.assert_close(actual_sqrsum, expected_sqrsum, atol=8.0, rtol=5e-4)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike() or not has_tilelang(),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_mhc_post_tilelang(num_tokens, hidden_size, hc_mult):
    """Compare the `mhc_post_tilelang` op output with `mhc_post_ref`."""
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    # NOTE: keep the draw order fixed so the seeded inputs are reproducible.
    x = torch.randn((num_tokens, hidden_size), dtype=torch.bfloat16)
    residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
    post_layer_mix = torch.randn((num_tokens, hc_mult, 1), dtype=torch.float32)
    comb_res_mix = torch.randn((num_tokens, hc_mult, hc_mult), dtype=torch.float32)

    op_args = (x, residual, post_layer_mix, comb_res_mix)
    expected = mhc_post_ref(*op_args)
    actual = torch.ops.vllm.mhc_post_tilelang(*op_args)

    torch.testing.assert_close(actual, expected, atol=5e-2, rtol=1e-2)
@pytest.mark.skipif(
not (current_platform.is_cuda_alike() and has_tilelang()),
reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@@ -196,3 +321,42 @@ def test_hc_head_triton(num_tokens, hidden_size, hc_mult):
out_ref = hc_head_ref(residual, fn, hc_scale, hc_base, rms_eps, hc_eps)
torch.testing.assert_close(out, out_ref, atol=5e-2, rtol=1e-2)
@pytest.mark.skipif(
    not current_platform.is_cuda_alike() or not has_tilelang(),
    reason="CUDA or ROCm and tilelang required",
)
@pytest.mark.parametrize("num_tokens", [1, 4, 8, 128])
@pytest.mark.parametrize("hidden_size", [4096, 7168])
@pytest.mark.parametrize("hc_mult", [4])
def test_hc_head_tilelang(num_tokens, hidden_size, hc_mult):
    """Check `hc_head_fused_kernel_tilelang` fills `out` and matches `hc_head_ref`."""
    torch.set_default_device(DEVICE)
    set_random_seed(0)

    residual = torch.randn((num_tokens, hc_mult, hidden_size), dtype=torch.bfloat16)
    fn = torch.randn((hc_mult, hc_mult * hidden_size), dtype=torch.float32) * 1e-4
    hc_scale = torch.randn((1,), dtype=torch.float32) * 0.1
    hc_base = torch.randn((hc_mult,), dtype=torch.float32) * 0.1
    eps = 1e-6

    # Pre-fill the output with NaN so any element the kernel fails to write
    # is caught by the isnan check below.
    out = torch.full((num_tokens, hidden_size), float("nan"), dtype=torch.bfloat16)

    # Positional order: residual, fn, hc_scale, hc_base, out, hidden_size,
    # rms_eps, hc_eps, hc_mult.
    returned = torch.ops.vllm.hc_head_fused_kernel_tilelang(
        residual, fn, hc_scale, hc_base, out, hidden_size, eps, eps, hc_mult
    )

    # The op writes into `out` in place and returns nothing.
    assert returned is None
    assert not torch.isnan(out).any()

    expected = hc_head_ref(residual, fn, hc_scale, hc_base, eps, eps)
    torch.testing.assert_close(out, expected, atol=5e-2, rtol=1e-2)
+224 -40
View File
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from functools import cache
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import torch
@@ -10,8 +10,9 @@ from vllm.platforms import current_platform
from vllm.utils.import_utils import has_tilelang
from vllm.utils.math_utils import cdiv
# tilelang is only available on CUDA platforms
if TYPE_CHECKING or current_platform.is_cuda():
# TileLang is used for MHC on CUDA and ROCm. Keep non-GPU imports cheap so
# registering the Python wrapper modules does not require TileLang everywhere.
if TYPE_CHECKING or current_platform.is_cuda_alike():
if not has_tilelang():
raise ImportError(
"tilelang is required for mhc but is not installed. Install it with "
@@ -23,6 +24,8 @@ else:
tilelang = None # type: ignore[assignment]
T = None # type: ignore[assignment]
ENABLE_PDL = current_platform.is_arch_support_pdl() and current_platform.is_cuda()
@cache
def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int:
@@ -37,12 +40,17 @@ def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int:
return split_k
# Pass configuration handed to the `@tilelang.jit(pass_configs=...)` decorators
# below. Warp specialization and TMA lowering are disabled unconditionally.
pass_configs: dict[tilelang.PassConfigKey, Any] = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
}
# The PTXAS register-usage level is only added when running on a CUDA
# platform; on other platforms the key is left out of the dict entirely.
if current_platform.is_cuda():
pass_configs[tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL] = 10
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
},
pass_configs=pass_configs,
)
def mhc_pre_big_fuse_tilelang(
gemm_out_mul,
@@ -78,7 +86,8 @@ def mhc_pre_big_fuse_tilelang(
layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type]
with T.Kernel(num_tokens, threads=96) as i:
T.pdl_sync()
if ENABLE_PDL:
T.pdl_sync()
##################################################################
# _pre_norm_fn_fwd_norm
rms = T.alloc_fragment(1, T.float32)
@@ -174,18 +183,16 @@ def mhc_pre_big_fuse_tilelang(
ol[i1_h] += pre * xl[i_hc, i1_h]
T.copy(ol, layer_input[i, i0_h * hidden_block])
T.pdl_trigger()
if ENABLE_PDL:
T.pdl_trigger()
# Copied from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/mhc.py#L478
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
},
pass_configs=pass_configs,
)
def mhc_pre_big_fuse_with_norm_tilelang(
gemm_out_mul,
@@ -230,7 +237,8 @@ def mhc_pre_big_fuse_with_norm_tilelang(
T.clear(mixes)
rms[0] = 0
T.pdl_sync()
if ENABLE_PDL:
T.pdl_sync()
for i_split in T.serial(n_splits):
rms[0] += gemm_out_sqrsum[i_split, i]
@@ -341,15 +349,12 @@ def mhc_pre_big_fuse_with_norm_tilelang(
T.copy(ol, layer_input[i, i0_h * hidden_block])
T.pdl_trigger()
if ENABLE_PDL:
T.pdl_trigger()
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
},
pass_configs=pass_configs,
)
def mhc_fused_tilelang(
comb_mix,
@@ -390,8 +395,8 @@ def mhc_fused_tilelang(
with T.Kernel(m, n_tiles, split_k, threads=n_thr) as (i_n, i_nt, i_ks):
tid = T.get_thread_binding()
warp_id = T.get_warp_idx()
lane = T.get_lane_idx()
warp_id = tid // 32
lane = tid % 32
s_warp = T.alloc_shared((num_warps, tile_n + 1), T.float32)
s_post = T.alloc_shared((hc,), T.float32)
@@ -407,7 +412,8 @@ def mhc_fused_tilelang(
T.clear(sqr)
h_split_start = i_ks * h_per_split
T.pdl_sync()
if ENABLE_PDL:
T.pdl_sync()
T.copy(post_mix[i_n, 0], s_post)
T.copy(comb_mix[i_n, 0, 0], s_comb)
@@ -466,15 +472,12 @@ def mhc_fused_tilelang(
v2 += s_warp[w, tile_n]
rp_out[i_ks, i_n] = v2
T.pdl_trigger()
if ENABLE_PDL:
T.pdl_trigger()
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
},
pass_configs=pass_configs,
)
def mhc_post_tilelang(
a,
@@ -507,7 +510,8 @@ def mhc_post_tilelang(
a_local = T.alloc_fragment((hc, hc), T.float32)
c_local = T.alloc_fragment(hc, T.float32)
T.pdl_sync()
if ENABLE_PDL:
T.pdl_sync()
T.copy(a[i_n, 0, 0], a_local)
T.copy(c[i_n, 0], c_local)
@@ -523,15 +527,193 @@ def mhc_post_tilelang(
x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h]
T.copy(x_local, x[i_n, 0, i0_h * h_blk])
T.pdl_trigger()
if ENABLE_PDL:
T.pdl_trigger()
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10,
},
pass_configs=pass_configs,
)
# Split-K prenorm GEMM kernel.
# For every token row of `x` and every K split it writes:
#   out[split, token, o] = sum over the split's K range of x[token, k] * fn[o, k]
#   sqrsum[split, token] = sum over the split's K range of x[token, k] ** 2
# Parameters:
#   hidden_size, hc_mult: the K extent is hc_mult * hidden_size.
#   n_out:    number of rows of `fn` / output columns.
#   n_thr:    threads per block.
#   tile_n:   output columns handled by one block.
#   n_splits: number of K splits (leading dimension of `out` and `sqrsum`).
def hc_prenorm_gemm_tilelang(
x,
fn,
out,
sqrsum,
hidden_size: int,
hc_mult: int = 4,
n_out: int = 24,
n_thr: int = 512,
tile_n: int = 12,
n_splits: int = 1,
) -> tilelang.JITKernel:
num_tokens = T.dynamic("num_tokens")
hc_hidden_size = hc_mult * hidden_size
# Floor divisions: any K remainder that does not divide evenly by n_splits
# and then by n_thr is silently left out of the sums.
k_per_split = hc_hidden_size // n_splits
k_iters = k_per_split // n_thr
n_tiles = T.ceildiv(n_out, tile_n)
x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) # type: ignore[no-redef, valid-type]
fn: T.Tensor((n_out, hc_hidden_size), T.float32) # type: ignore[no-redef, valid-type]
out: T.Tensor((n_splits, num_tokens, n_out), T.float32) # type: ignore[no-redef, valid-type]
sqrsum: T.Tensor((n_splits, num_tokens), T.float32) # type: ignore[no-redef, valid-type]
# One block per (token, output tile, K split).
with T.Kernel(num_tokens, n_tiles, n_splits, threads=n_thr) as (
i_n,
i_t,
i_s,
):
tid = T.get_thread_binding()
# Per-thread partial results: tile_n dot products and one sum of squares.
acc = T.alloc_local((tile_n,), T.float32)
sqr = T.alloc_local((1,), T.float32)
T.clear(acc)
T.clear(sqr)
if ENABLE_PDL:
T.pdl_sync()
# Each thread walks its split's K range with a stride of n_thr.
for it in T.serial(k_iters):
i_k = i_s * k_per_split + it * n_thr + tid
x_val = x[i_n, i_k]
for i_o in T.unroll(tile_n):
out_idx = i_t * tile_n + i_o
# The last tile may extend past n_out; skip those columns.
if out_idx < n_out:
acc[i_o] += x_val * fn[out_idx, i_k]
# Only output tile 0 accumulates (and later writes) the sum of squares.
# NOTE(review): x is declared bfloat16 above — confirm that
# x_val * x_val is evaluated in float32 rather than rounded to
# bfloat16 before being accumulated.
if i_t == 0:
sqr[0] += x_val * x_val
# First reduction stage: combine the per-thread partials within each warp.
for i_o in T.unroll(tile_n):
acc[i_o] = T.warp_reduce_sum(acc[i_o])
if i_t == 0:
sqr[0] = T.warp_reduce_sum(sqr[0])
# Lane and warp indices are derived from tid using a fixed width of 32.
lane = tid % 32
warp_id = tid // 32
num_warps = n_thr // 32
# Staging buffers holding one partial result per warp.
warp_acc = T.alloc_shared((num_warps, tile_n), T.float32)
warp_sqr = T.alloc_shared(num_warps, T.float32)
# Lane 0 of each warp publishes that warp's reduced values.
if lane == 0:
for i_o in T.unroll(tile_n):
warp_acc[warp_id, i_o] = acc[i_o]
if i_t == 0:
warp_sqr[warp_id] = sqr[0]
T.sync_threads()
# Second reduction stage: warp 0 sums the per-warp partials and stores them.
if warp_id == 0:
# One lane per output column of the tile.
# NOTE(review): this appears to require tile_n <= 32 — confirm.
if lane < tile_n:
reduced_acc = T.alloc_var(T.float32, init=0.0)
for i_w in T.unroll(num_warps):
reduced_acc += warp_acc[i_w, lane]
out_idx = i_t * tile_n + lane
if out_idx < n_out:
out[i_s, i_n, out_idx] = reduced_acc
# A single thread of tile 0 writes the sum of squares for this (split, token).
if lane == 0 and i_t == 0:
reduced_sqr = T.alloc_var(T.float32, init=0.0)
for i_w in T.unroll(num_warps):
reduced_sqr += warp_sqr[i_w]
sqrsum[i_s, i_n] = reduced_sqr
if ENABLE_PDL:
T.pdl_trigger()
# Prenorm GEMM kernel variant that processes `block_m` tokens per block and
# does not split K (the leading dimension of `out` and `sqrsum` is 1).
# For every token it writes:
#   out[0, token, o] = sum over k of x[token, k] * fn[o, k]
#   sqrsum[0, token] = sum over k of x[token, k] ** 2
# Each `fn` value is loaded once per K step and reused for all block_m tokens.
@tilelang.jit(
pass_configs=pass_configs,
)
def hc_prenorm_gemm_block_m_tilelang(
x,
fn,
out,
sqrsum,
hidden_size: int,
hc_mult: int = 4,
n_out: int = 24,
n_thr: int = 512,
tile_n: int = 12,
block_m: int = 2,
) -> tilelang.JITKernel:
num_tokens = T.dynamic("num_tokens")
hc_hidden_size = hc_mult * hidden_size
# Floor division: a K remainder that is not a multiple of n_thr is
# silently left out of the sums.
k_iters = hc_hidden_size // n_thr
n_tiles = T.ceildiv(n_out, tile_n)
m_tiles = T.ceildiv(num_tokens, block_m)
x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) # type: ignore[no-redef, valid-type]
fn: T.Tensor((n_out, hc_hidden_size), T.float32) # type: ignore[no-redef, valid-type]
out: T.Tensor((1, num_tokens, n_out), T.float32) # type: ignore[no-redef, valid-type]
sqrsum: T.Tensor((1, num_tokens), T.float32) # type: ignore[no-redef, valid-type]
# One block per (token tile, output tile).
with T.Kernel(m_tiles, n_tiles, threads=n_thr) as (i_mt, i_t):
tid = T.get_thread_binding()
# Per-thread partial results for each of the block_m tokens.
acc = T.alloc_local((block_m, tile_n), T.float32)
sqr = T.alloc_local((block_m,), T.float32)
T.clear(acc)
T.clear(sqr)
if ENABLE_PDL:
T.pdl_sync()
# Each thread walks the K range with a stride of n_thr.
for it in T.serial(k_iters):
i_k = it * n_thr + tid
# Load this K step's fn values once; columns past n_out are zero-filled
# so the inner accumulation needs no bounds check.
fn_val = T.alloc_local((tile_n,), T.float32)
for i_o in T.unroll(tile_n):
out_idx = i_t * tile_n + i_o
if out_idx < n_out:
fn_val[i_o] = fn[out_idx, i_k]
else:
fn_val[i_o] = 0.0
for i_m in T.unroll(block_m):
token_idx = i_mt * block_m + i_m
# The last token tile may extend past num_tokens; skip those rows.
if token_idx < num_tokens:
x_val = x[token_idx, i_k]
for i_o in T.unroll(tile_n):
acc[i_m, i_o] += x_val * fn_val[i_o]
# Only output tile 0 accumulates (and later writes) the sum of squares.
# NOTE(review): x is declared bfloat16 above — confirm that
# x_val * x_val is evaluated in float32 rather than rounded to
# bfloat16 before being accumulated.
if i_t == 0:
sqr[i_m] += x_val * x_val
# First reduction stage: combine the per-thread partials within each warp.
for i_m in T.unroll(block_m):
for i_o in T.unroll(tile_n):
acc[i_m, i_o] = T.warp_reduce_sum(acc[i_m, i_o])
if i_t == 0:
sqr[i_m] = T.warp_reduce_sum(sqr[i_m])
# Lane and warp indices are derived from tid using a fixed width of 32.
lane = tid % 32
warp_id = tid // 32
num_warps = n_thr // 32
# Staging buffers holding one partial result per warp and token.
warp_acc = T.alloc_shared((num_warps, block_m, tile_n), T.float32)
warp_sqr = T.alloc_shared((num_warps, block_m), T.float32)
# Lane 0 of each warp publishes that warp's reduced values.
if lane == 0:
for i_m in T.unroll(block_m):
for i_o in T.unroll(tile_n):
warp_acc[warp_id, i_m, i_o] = acc[i_m, i_o]
if i_t == 0:
warp_sqr[warp_id, i_m] = sqr[i_m]
T.sync_threads()
# Second reduction stage: warp 0 sums the per-warp partials and stores them.
if warp_id == 0:
for i_m in T.unroll(block_m):
token_idx = i_mt * block_m + i_m
if token_idx < num_tokens:
# One lane per output column of the tile.
# NOTE(review): this appears to require tile_n <= 32 — confirm.
if lane < tile_n:
reduced_acc = T.alloc_var(T.float32, init=0.0)
for i_w in T.unroll(num_warps):
reduced_acc += warp_acc[i_w, i_m, lane]
out_idx = i_t * tile_n + lane
if out_idx < n_out:
out[0, token_idx, out_idx] = reduced_acc
# A single thread of tile 0 writes the sum of squares for this token.
if lane == 0 and i_t == 0:
reduced_sqr = T.alloc_var(T.float32, init=0.0)
for i_w in T.unroll(num_warps):
reduced_sqr += warp_sqr[i_w, i_m]
sqrsum[0, token_idx] = reduced_sqr
if ENABLE_PDL:
T.pdl_trigger()
@tilelang.jit(
pass_configs=pass_configs,
)
def hc_head_fuse_tilelang(
residual,
@@ -566,7 +748,8 @@ def hc_head_fuse_tilelang(
out: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef,valid-type]
with T.Kernel(num_tokens, threads=n_thr) as i:
T.pdl_sync()
if ENABLE_PDL:
T.pdl_sync()
# ------------------------------------------------------------------
# Pass 1 for each residual channel m_c and h_block:
@@ -624,4 +807,5 @@ def hc_head_fuse_tilelang(
T.copy(ol, out[i, i0_h * h_block], disable_tma=True)
T.pdl_trigger()
if ENABLE_PDL:
T.pdl_trigger()
+144 -26
View File
@@ -5,6 +5,88 @@ import torch
from vllm.utils.torch_utils import direct_register_custom_op
def _torch_hc_prenorm_gemm(
x: torch.Tensor,
fn: torch.Tensor,
out: torch.Tensor,
sqrsum: torch.Tensor,
) -> None:
assert out.shape[0] == 1
assert sqrsum.shape[0] == 1
x_float = x.float()
out[0].copy_(x_float @ fn.t())
sqrsum[0].copy_(x_float.square().sum(dim=-1))
def _tilelang_hc_prenorm_gemm(
    x: torch.Tensor,
    fn: torch.Tensor,
    out: torch.Tensor,
    sqrsum: torch.Tensor,
    hidden_size: int,
    hc_mult: int,
    tile_n: int = 12,
    n_thr: int = 512,
    n_splits: int = 1,
) -> None:
    """Run the TileLang prenorm GEMM, choosing a kernel variant by token count.

    Results are written into ``out`` and ``sqrsum``, whose leading dimension
    must equal ``n_splits``. When called with the default ``tile_n``/``n_thr``
    and ``n_splits == 1``:

    * ``x.shape[0] >= 1024`` uses the block-M kernel with ``block_m=2``;
    * ``x.shape[0] < 128`` with ``x.shape[1]`` a multiple of 1024 uses the
      regular kernel with 1024 threads and ``tile_n=4``.

    Every other case runs the regular kernel with the arguments as given.
    """
    from vllm._tilelang_ops import (
        hc_prenorm_gemm_block_m_tilelang,
        hc_prenorm_gemm_tilelang,
    )

    assert out.shape[0] == n_splits
    assert sqrsum.shape[0] == n_splits
    flat_hidden = x.shape[1]
    assert flat_hidden == hc_mult * hidden_size
    assert flat_hidden % n_splits == 0
    assert (flat_hidden // n_splits) % n_thr == 0

    num_tokens = x.shape[0]
    n_out = fn.shape[0]
    # The special-cased launch shapes only apply to the default configuration.
    is_default_launch = n_splits == 1 and tile_n == 12 and n_thr == 512

    if is_default_launch and num_tokens >= 1024:
        hc_prenorm_gemm_block_m_tilelang(
            x, fn, out, sqrsum, hidden_size, hc_mult, n_out, n_thr, tile_n, 2
        )
        return

    if is_default_launch and num_tokens < 128 and flat_hidden % 1024 == 0:
        # Small batches: switch to 1024 threads with narrower output tiles.
        n_thr, tile_n = 1024, 4

    hc_prenorm_gemm_tilelang(
        x, fn, out, sqrsum, hidden_size, hc_mult, n_out, n_thr, tile_n, n_splits
    )
def mhc_pre_tilelang(
residual: torch.Tensor,
fn: torch.Tensor,
@@ -80,10 +162,16 @@ def mhc_pre_tilelang(
residual_flat = residual.view(-1, hc_mult, hidden_size)
num_tokens = residual_flat.shape[0]
# these numbers are from deepgemm kernel impl
block_k = 64
block_m = 64
n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m))
from vllm.utils.deep_gemm import is_deep_gemm_supported
use_deep_gemm = is_deep_gemm_supported()
if use_deep_gemm:
# these numbers are from deepgemm kernel impl
block_k = 64
block_m = 64
n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m))
else:
n_splits = 1
post_mix = torch.empty(
num_tokens, hc_mult, dtype=torch.float32, device=residual.device
@@ -102,13 +190,24 @@ def mhc_pre_tilelang(
n_splits, num_tokens, dtype=torch.float32, device=residual.device
)
tf32_hc_prenorm_gemm(
residual_flat.view(num_tokens, hc_mult * hidden_size),
fn,
gemm_out_mul,
gemm_out_sqrsum,
n_splits,
)
residual_2d = residual_flat.view(num_tokens, hc_mult * hidden_size)
if use_deep_gemm:
tf32_hc_prenorm_gemm(
residual_2d,
fn,
gemm_out_mul,
gemm_out_sqrsum,
n_splits,
)
else:
_tilelang_hc_prenorm_gemm(
residual_2d,
fn,
gemm_out_mul,
gemm_out_sqrsum,
hidden_size,
hc_mult,
)
if norm_weight is None:
mhc_pre_big_fuse_tilelang(
@@ -304,16 +403,24 @@ def mhc_fused_post_pre_tilelang(
post_layer_mix_flat = post_layer_mix.view(num_tokens, hc_mult)
comb_res_mix_flat = comb_res_mix.view(num_tokens, hc_mult, hc_mult)
fma_token_threshold = 16
if num_tokens <= fma_token_threshold:
from vllm.utils.deep_gemm import is_deep_gemm_supported
use_deep_gemm = is_deep_gemm_supported()
use_small_fma = num_tokens <= 16
if use_small_fma:
# TODO(gnovack): investigate autotuning these heuristics
tile_n = 2 if num_tokens < 8 else 3
n_splits = 8 if (num_tokens < 8 and hidden_size <= 4096) else 4
else:
# these number are from deepgemm kernel impl
block_k = 64
block_m = 64
n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m))
if use_deep_gemm:
# these numbers are from deepgemm kernel impl
block_k = 64
block_m = 64
n_splits = compute_num_split(
block_k, hc_hidden_size, cdiv(num_tokens, block_m)
)
else:
n_splits = 1
gemm_out_mul = torch.empty(
n_splits,
@@ -348,7 +455,7 @@ def mhc_fused_post_pre_tilelang(
device=residual.device,
)
if num_tokens <= fma_token_threshold:
if use_small_fma:
mhc_fused_tilelang(
comb_res_mix_flat,
residual_flat,
@@ -375,15 +482,26 @@ def mhc_fused_post_pre_tilelang(
residual.shape[-1],
)
from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm
residual_cur_2d = residual_cur.view(num_tokens, hc_mult * hidden_size)
if use_deep_gemm:
from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm
tf32_hc_prenorm_gemm(
residual_cur.view(num_tokens, hc_mult * hidden_size),
fn,
gemm_out_mul,
gemm_out_sqrsum,
n_splits,
)
tf32_hc_prenorm_gemm(
residual_cur_2d,
fn,
gemm_out_mul,
gemm_out_sqrsum,
n_splits,
)
else:
_tilelang_hc_prenorm_gemm(
residual_cur_2d,
fn,
gemm_out_mul,
gemm_out_sqrsum,
hidden_size,
hc_mult,
)
if norm_weight is None:
mhc_pre_big_fuse_tilelang(
+126 -20
View File
@@ -3,8 +3,12 @@
import torch
# this import will also register the custom ops
# import vllm.model_executor.kernels.mhc # noqa: F401
import vllm.model_executor.kernels.mhc as mhc_kernels
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.import_utils import has_tilelang
HAS_TILELANG = has_tilelang()
# --8<-- [start:mhc_pre]
@@ -85,6 +89,52 @@ class MHCPreOp(CustomOp):
# sinkhorn_repeat,
# )
# else:
if HAS_TILELANG:
return torch.ops.vllm.mhc_pre_tilelang(
residual,
fn,
hc_scale,
hc_base,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_mult_value,
sinkhorn_repeat,
n_splits,
norm_weight,
norm_eps,
)
else:
return self.forward_native(
residual,
fn,
hc_scale,
hc_base,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_mult_value,
sinkhorn_repeat,
n_splits,
norm_weight,
norm_eps,
)
def forward_native(
self,
residual: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
rms_eps: float,
hc_pre_eps: float,
hc_sinkhorn_eps: float,
hc_post_mult_value: float,
sinkhorn_repeat: int,
n_splits: int = 1,
norm_weight: torch.Tensor | None = None,
norm_eps: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return mhc_kernels.mhc_pre_torch(
residual,
fn,
@@ -97,9 +147,6 @@ class MHCPreOp(CustomOp):
sinkhorn_repeat,
)
def forward_native(self, *args, **kwargs):
raise NotImplementedError("Native implementation of mhc_pre is not available")
# --8<-- [start:mhc_post]
@CustomOp.register("mhc_post")
@@ -147,6 +194,20 @@ class MHCPostOp(CustomOp):
# comb_res_mix,
# )
# else:
if HAS_TILELANG:
return torch.ops.vllm.mhc_post_tilelang(
x, residual, post_layer_mix, comb_res_mix
)
else:
return self.forward_native(x, residual, post_layer_mix, comb_res_mix)
def forward_native(
self,
x: torch.Tensor,
residual: torch.Tensor,
post_layer_mix: torch.Tensor,
comb_res_mix: torch.Tensor,
) -> torch.Tensor:
return mhc_kernels.mhc_post_torch(
x,
residual,
@@ -154,9 +215,6 @@ class MHCPostOp(CustomOp):
comb_res_mix,
)
def forward_native(self, *args, **kwargs):
raise NotImplementedError("Native implementation of mhc_post is not available")
# --8<-- [start:hc_head]
@CustomOp.register("hc_head")
@@ -220,17 +278,32 @@ class HCHeadOp(CustomOp):
out = torch.empty(
num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device
)
torch.ops.vllm.hc_head_triton(
hs_flat,
hc_fn,
hc_scale,
hc_base,
out,
hidden_size,
rms_norm_eps,
hc_eps,
hc_mult,
)
if HAS_TILELANG:
torch.ops.vllm.hc_head_fused_kernel_tilelang(
hs_flat,
hc_fn,
hc_scale,
hc_base,
out,
hidden_size,
rms_norm_eps,
hc_eps,
hc_mult,
)
else:
torch.ops.vllm.hc_head_triton(
hs_flat,
hc_fn,
hc_scale,
hc_base,
out,
hidden_size,
rms_norm_eps,
hc_eps,
hc_mult,
)
return out.view(*outer_shape, hidden_size)
def forward_native(self, *args, **kwargs):
@@ -290,9 +363,42 @@ class MHCFusedPostPreOp(CustomOp):
norm_eps,
)
def forward_hip(self, *args, **kwargs):
raise NotImplementedError(
"Hip implementation of mhc_fused_post_pre is not available"
def forward_hip(
self,
x: torch.Tensor,
residual: torch.Tensor,
post_layer_mix: torch.Tensor,
comb_res_mix: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
rms_eps: float,
hc_pre_eps: float,
hc_sinkhorn_eps: float,
hc_post_mult_value: float,
sinkhorn_repeat: int,
n_splits: int = 1,
tile_n: int = 1,
norm_weight: torch.Tensor | None = None,
norm_eps: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return torch.ops.vllm.mhc_fused_post_pre_tilelang(
x,
residual,
post_layer_mix,
comb_res_mix,
fn,
hc_scale,
hc_base,
rms_eps,
hc_pre_eps,
hc_sinkhorn_eps,
hc_post_mult_value,
sinkhorn_repeat,
n_splits,
tile_n,
norm_weight,
norm_eps,
)
def forward_native(self, *args, **kwargs):
+11 -7
View File
@@ -54,6 +54,7 @@ from vllm.models.deepseek_v4.attention import (
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_tilelang
class DeepseekV4MLP(nn.Module):
@@ -473,6 +474,7 @@ class DeepseekV4DecoderLayer(nn.Module):
self.mhc_pre = MHCPreOp()
self.mhc_post = MHCPostOp()
self.mhc_fused_post_pre = MHCFusedPostPreOp()
self.has_tilelang = has_tilelang()
def hc_pre(
self,
@@ -503,7 +505,7 @@ class DeepseekV4DecoderLayer(nn.Module):
):
return self.mhc_post(x, residual, post, comb)
def _forward_cuda(
def _forward_fused_post_pre(
self,
x: torch.Tensor,
positions: torch.Tensor,
@@ -555,7 +557,7 @@ class DeepseekV4DecoderLayer(nn.Module):
x = self.ffn(x, input_ids)
return x, residual, post_mix, res_mix
def _forward_rocm(
def _forward_unfused_post_pre(
self,
x: torch.Tensor,
positions: torch.Tensor,
@@ -594,12 +596,13 @@ class DeepseekV4DecoderLayer(nn.Module):
) -> tuple[
torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None
]:
if current_platform.is_rocm():
return self._forward_rocm(
if not self.has_tilelang:
return self._forward_unfused_post_pre(
x, positions, input_ids, post_mix, res_mix, residual
)
return self._forward_cuda(x, positions, input_ids, post_mix, res_mix, residual)
return self._forward_fused_post_pre(
x, positions, input_ids, post_mix, res_mix, residual
)
@support_torch_compile
@@ -682,6 +685,7 @@ class DeepseekV4Model(nn.Module):
requires_grad=False,
)
self.hc_head_op = HCHeadOp()
self.has_tilelang = has_tilelang()
# Pre-hc_head residual stream buffer for the MTP draft. Stable
# address (outside the cudagraph pool) so the copy_ in forward()
# refreshes it correctly across captured shapes.
@@ -748,7 +752,7 @@ class DeepseekV4Model(nn.Module):
res_mix,
residual,
)
if layer is not None and current_platform.is_cuda():
if layer is not None and self.has_tilelang:
hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix)
if not get_pp_group().is_last_rank:
+3 -1
View File
@@ -39,6 +39,7 @@ from vllm.model_executor.models.deepseek_v2 import get_spec_layer_idx_from_weigh
from vllm.model_executor.models.utils import maybe_prefix
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import has_tilelang
from .model import DeepseekV4DecoderLayer
@@ -118,6 +119,7 @@ class DeepSeekV4MultiTokenPredictorLayer(nn.Module):
)
self.hc_head_op = HCHeadOp()
self.has_tilelang = has_tilelang()
def forward(
self,
@@ -144,7 +146,7 @@ class DeepSeekV4MultiTokenPredictorLayer(nn.Module):
hidden_states, residual, post_mix, res_mix = self.mtp_block(
positions=positions, x=hidden_states, input_ids=None
)
if current_platform.is_cuda():
if self.has_tilelang:
hidden_states = self.mtp_block.hc_post(
hidden_states, residual, post_mix, res_mix
)
+9
View File
@@ -592,6 +592,15 @@ class CudaPlatformBase(Platform):
default, rms_norm=rms_norm, fused_add_rms_norm=rms_norm
)
@classmethod
def is_arch_support_pdl(cls) -> bool:
    """Report whether the current CUDA device can use PDL.

    Returns True when the major compute capability of the current device is
    at least 9. Any failure while querying the device is treated as
    "not supported" and yields False.
    """
    try:
        major, _minor = torch.cuda.get_device_capability(
            torch.cuda.current_device()
        )
    except Exception:
        # Best-effort probe: an unavailable or uninitialised device simply
        # means PDL is not used.
        return False
    else:
        return major >= 9
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
+7
View File
@@ -1016,6 +1016,13 @@ class Platform:
# Native always used by default. Platforms can override this behavior.
return IrOpPriorityConfig.with_default(["native"])
@classmethod
def is_arch_support_pdl(cls) -> bool:
"""
Whether the current platform supports PDL (Programmatic Dependent Launch).

This base implementation always returns False; platform subclasses can
override it to report support.
"""
return False
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
+1
View File
@@ -430,6 +430,7 @@ def has_triton_kernels() -> bool:
return is_available
@cache
def has_tilelang() -> bool:
"""Whether the optional `tilelang` package is available.

The result is memoized by the `cache` decorator, so repeated calls reuse
the first answer instead of calling `_has_module` again.
"""
return _has_module("tilelang")