[Kernel][Helion] Fix inductor fusion of Helion HOP (#39944)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Yanan Cao
2026-04-16 04:41:26 -07:00
committed by GitHub
parent 9965f501a8
commit edc3648966
5 changed files with 59 additions and 6 deletions
+1 -1
View File
@@ -769,7 +769,7 @@ steps:
- tests/kernels/helion/
- vllm/platforms/rocm.py
commands:
- pip install helion==0.3.3
- pip install helion==1.0.0
- pytest -v -s kernels/helion/
+1 -1
View File
@@ -155,7 +155,7 @@ steps:
- vllm/utils/import_utils.py
- tests/kernels/helion/
commands:
- pip install helion==0.3.3
- pip install helion==1.0.0
- pytest -v -s kernels/helion/
+1 -1
View File
@@ -1104,7 +1104,7 @@ setup(
# NOTE: When updating helion version, also update CI files:
# - .buildkite/test_areas/kernels.yaml
# - .buildkite/test-amd.yaml
"helion": ["helion==0.3.3"],
"helion": ["helion==1.0.0"],
# Optional deps for gRPC server (vllm serve --grpc)
"grpc": ["smg-grpc-servicer[vllm] >= 0.5.0"],
# Optional deps for OpenTelemetry tracing
+48
View File
@@ -36,9 +36,11 @@ from vllm.kernels.helion.register import (
)
if _HOP_AVAILABLE:
from helion._compat import supports_torch_compile_fusion
from helion._compiler._dynamo.higher_order_ops import (
helion_kernel_wrapper_mutation,
)
from torch._inductor.utils import run_and_get_code
def _add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -1003,3 +1005,49 @@ class TestTorchCompileHOP:
"Compiled execution result doesn't match eager execution. "
f"Max difference: {torch.max(torch.abs(compiled_result - eager_result))}"
)
# Skip unless the HOP is available AND Helion reports fusion support. The
# `and` short-circuits, so supports_torch_compile_fusion (imported only under
# `if _HOP_AVAILABLE:`) is never called when the HOP is unavailable.
@pytest.mark.skipif(
not (_HOP_AVAILABLE and supports_torch_compile_fusion()),
reason="Requires PyTorch with Helion inductor fusion support",
)
def test_inductor_backend_compiles_helion_hop(self):
"""Test torch.compile with inductor backend and Helion fusion enabled."""
# NOTE(review): indentation is not visible in this view, so which of the
# statements below sit inside the `with` block cannot be confirmed here.
# A single named config; the config_picker below always selects it.
configs = {"default": helion.Config(block_sizes=[4, 4])}
with dummy_kernel_registry(configs=configs) as register:
# Register _add_kernel under a test-only op name with
# torch_compile_fusion turned on in its Helion settings.
add_helion_kernel = register(
op_name="test_inductor_add_kernel",
config_picker=lambda args, keys: "default",
helion_settings=helion.Settings(
torch_compile_fusion=True, static_shapes=False
),
)(_add_kernel)
# Elementwise ops before (x * 2.0, y + 1.0) and after (relu) the Helion
# call give inductor prologue/epilogue work that it can fuse.
def f(x, y):
x = x * 2.0
y = y + 1.0
out = add_helion_kernel(x, y)
return out.relu()
# Presumably clears Dynamo state left over from earlier tests so this
# compile starts fresh -- TODO confirm.
torch._dynamo.reset()
# fullgraph=True: a graph break raises instead of silently falling back.
compiled_f = torch.compile(f, backend="inductor", fullgraph=True)
x = torch.randn(4, 4, device="cuda")
y = torch.randn(4, 4, device="cuda")
# run_and_get_code yields the call result plus the generated source code
# strings, which are inspected for kernel definitions further down.
compiled_result, source_codes = run_and_get_code(compiled_f, x, y)
# Eager reference computed on the same inputs.
eager_result = f(x, y)
assert torch.allclose(compiled_result, eager_result, atol=1e-5, rtol=1e-5), (
"Inductor-compiled result doesn't match eager execution. "
f"Max difference: {torch.max(torch.abs(compiled_result - eager_result))}"
)
# With fusion enabled, prologue/epilogue ops should be fused into
# a single triton kernel rather than generating separate kernels.
# Each generated kernel is counted by its "@triton.jit" decorator text.
kernel_count = sum(code.count("@triton.jit") for code in source_codes)
assert kernel_count == 1, (
f"Expected 1 fused triton kernel, got {kernel_count}. "
"Prologue/epilogue ops were not fused into the Helion kernel."
)
+8 -3
View File
@@ -63,8 +63,10 @@ from helion.runtime.settings import default_autotuner_fn
_HOP_AVAILABLE = requires_torch_version("2.11")
if _HOP_AVAILABLE:
from helion._compat import supports_torch_compile_fusion
from helion._compiler._dynamo.higher_order_ops import helion_kernel_side_table
from helion._compiler._dynamo.variables import HelionKernelVariable
from helion.runtime.kernel import Kernel
from torch._dynamo.guards import GuardBuilder
from torch._dynamo.variables.builder import VariableBuilder
@@ -475,19 +477,22 @@ if _HOP_AVAILABLE:
"""Register HelionKernelWrapper with Dynamo's VariableBuilder.
When Dynamo encounters a HelionKernelWrapper during tracing, this
extracts the underlying Helion Kernel, registers it in the side table,
and returns Helion's own HelionKernelVariable to handle HOP emission.
extracts the underlying Helion Kernel and delegates to Helion's own
registered Kernel handler, which handles HOP emission, side table
registration, and inductor lowering setup.
"""
def wrap_helion_kernel_wrapper(
    builder: VariableBuilder, value: HelionKernelWrapper
):
    """Build the Dynamo variable for a ``HelionKernelWrapper``.

    The wrapper's configured op is unwrapped to its decorated Helion
    kernel. When ``supports_torch_compile_fusion()`` is true, the handler
    registered for ``Kernel`` in ``VariableBuilder``'s type dispatch table
    is invoked with that kernel and its result is returned. Otherwise the
    kernel is added to ``helion_kernel_side_table``, an ``ID_MATCH`` guard
    is installed on the builder, and a ``HelionKernelVariable`` is returned.

    Args:
        builder: The ``VariableBuilder`` currently wrapping ``value``.
        value: The wrapper whose underlying Helion kernel is exposed.

    Returns:
        Whatever the ``Kernel`` handler returns on the fusion path, or a
        ``HelionKernelVariable`` on the fallback path.
    """
    decorated_kernel = value.get_configured_op()._decorated_kernel

    if not supports_torch_compile_fusion():
        # Fallback path: register the kernel in the side table ourselves
        # and guard on the wrapper's identity before building the variable.
        side_table_idx = helion_kernel_side_table.add_kernel(decorated_kernel)
        builder.install_guards(GuardBuilder.ID_MATCH)
        return HelionKernelVariable(
            decorated_kernel, side_table_idx, source=builder.source
        )

    # Fusion path: hand the raw Kernel to the handler registered for it.
    type_dispatch = VariableBuilder._type_dispatch()
    return type_dispatch[Kernel](builder, decorated_kernel)
# Register with Dynamo's type dispatch system
dispatch = VariableBuilder._type_dispatch()
dispatch[HelionKernelWrapper] = wrap_helion_kernel_wrapper