[ROCm] mori: add InterNodeV1LL inter-node kernel selection via VLLM_MORI_INTERNODE_KERNEL (#41751)

Signed-off-by: jatseng-ai <jatseng@amd.com>
This commit is contained in:
jatseng-ai
2026-05-27 09:33:32 -07:00
committed by GitHub
parent 41688e2dc7
commit 05c50c721e
7 changed files with 36 additions and 14 deletions
@@ -249,7 +249,7 @@ class Config:
def needs_mori(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend == "mori"
return info.backend in ("mori_high_throughput", "mori_low_latency")
def all2all_backend(self):
info = prepare_finalize_info(self.prepare_finalize_type)
@@ -232,7 +232,7 @@ if has_mori():
standard_format,
fp8_types,
blocked_quantization_support=True,
backend="mori",
backend="mori_high_throughput",
supports_apply_weight_on_input=False,
)
+5 -3
View File
@@ -92,7 +92,7 @@ PARALLEL_COMBOS = [
BACKENDS = ["allgather_reducescatter"]
if has_mori():
BACKENDS += ["mori"]
BACKENDS += ["mori_high_throughput", "mori_low_latency"]
if has_flashinfer_nvlink_two_sided():
BACKENDS += ["flashinfer_nvlink_two_sided"]
@@ -118,7 +118,8 @@ QUANT_METHODS = [
# fmt: off
BACKEND_SUPPORTED_QUANTS: dict[str, set[str | None]] = {
"allgather_reducescatter": {None, "fp8", "modelopt_fp8", "modelopt_fp4"}, # noqa: E501
"mori": {None, "fp8", "modelopt_fp8"},
"mori_high_throughput": {None, "fp8", "modelopt_fp8"},
"mori_low_latency": {None, "fp8", "modelopt_fp8"},
"flashinfer_nvlink_two_sided": {None, "fp8_blocked", "modelopt_fp4"}, # noqa: E501
"flashinfer_nvlink_one_sided": {None, "modelopt_fp4"}, # noqa: E501
"deepep_low_latency": {None, "fp8_blocked", "modelopt_fp4"}, # noqa: E501
@@ -129,7 +130,8 @@ BACKEND_SUPPORTED_QUANTS: dict[str, set[str | None]] = {
# Map from backend -> (DP/EP support, DP support, TP support, SP support)
BACKEND_EP_DP_TP_SUPPORT: dict[str, tuple[bool, bool, bool, bool]] = {
"allgather_reducescatter": (True, True, True, True),
"mori": (True, False, False, True),
"mori_high_throughput": (True, False, False, True),
"mori_low_latency": (True, False, False, True),
"flashinfer_nvlink_two_sided": (False, True, False, False),
"flashinfer_nvlink_one_sided": (False, True, False, False),
"deepep_low_latency": (True, False, False, True),
+6 -3
View File
@@ -42,7 +42,8 @@ All2AllBackend = Literal[
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"mori_high_throughput",
"mori_low_latency",
"nixl_ep",
"allgather_reducescatter",
"flashinfer_all2allv", # temporary alias for flashinfer_nvlink_two_sided
@@ -177,7 +178,8 @@ class ParallelConfig:
- "allgather_reducescatter": All2all based on allgather and reducescatter
- "deepep_high_throughput": Use deepep high-throughput kernels
- "deepep_low_latency": Use deepep low-latency kernels
- "mori": Use mori kernels
- "mori_high_throughput": MoRI EP with InterNodeV1 for multi-node
- "mori_low_latency": MoRI EP with InterNodeV1LL for multi-node
- "nixl_ep": Use nixl-ep kernels
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
@@ -621,7 +623,8 @@ class ParallelConfig:
"allgather_reducescatter",
"deepep_high_throughput",
"deepep_low_latency",
"mori",
"mori_high_throughput",
"mori_low_latency",
"nixl_ep",
)
and self.enable_expert_parallel
@@ -764,14 +764,19 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
class MoriAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
def __init__(self, cpu_group, all2all_backend: str):
assert has_mori(), (
"MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
" to install MoRI kernels."
) # noqa
assert all2all_backend in (
"mori_high_throughput",
"mori_low_latency",
), f"unsupported MoRI all2all backend: {all2all_backend!r}"
import mori
super().__init__(cpu_group)
self._all2all_backend = all2all_backend
self.handle_cache = Cache()
torch._C._distributed_c10d._register_process_group("mori", cpu_group)
@@ -805,8 +810,12 @@ class MoriAll2AllManager(All2AllManagerBase):
warp_num_per_block = 16
block_num = 80
else:
# multi node
kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
# Multi-node: kernel follows --all2all-backend (mirrors deepep_* split).
# mori_low_latency → InterNodeV1LL; mori_high_throughput → V1.
if self._all2all_backend == "mori_low_latency":
kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1LL
else:
kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
if on_gfx942():
warp_num_per_block = 16
block_num = 32
@@ -137,10 +137,15 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.all2all_manager = DeepEPLLAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "mori":
elif self.all2all_backend in (
"mori_high_throughput",
"mori_low_latency",
):
from .all2all import MoriAll2AllManager
self.all2all_manager = MoriAll2AllManager(self.cpu_group)
self.all2all_manager = MoriAll2AllManager(
self.cpu_group, self.all2all_backend
)
elif self.all2all_backend == "nixl_ep":
from .all2all import NixlEPAll2AllManager
@@ -1091,7 +1091,10 @@ class FusedMoEParallelConfig:
@property
def use_mori_kernels(self):
return self.use_all2all_kernels and self.all2all_backend == "mori"
return self.use_all2all_kernels and self.all2all_backend in (
"mori_high_throughput",
"mori_low_latency",
)
@property
def use_nixl_ep_kernels(self):