mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm] mori: add InterNodeV1LL inter-node kernel selection via VLLM_MORI_INTERNODE_KERNEL (#41751)
Signed-off-by: jatseng-ai <jatseng@amd.com>
This commit is contained in:
@@ -249,7 +249,7 @@ class Config:
|
||||
|
||||
def needs_mori(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
return info.backend == "mori"
|
||||
return info.backend in ("mori_high_throughput", "mori_low_latency")
|
||||
|
||||
def all2all_backend(self):
|
||||
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||
|
||||
@@ -232,7 +232,7 @@ if has_mori():
|
||||
standard_format,
|
||||
fp8_types,
|
||||
blocked_quantization_support=True,
|
||||
backend="mori",
|
||||
backend="mori_high_throughput",
|
||||
supports_apply_weight_on_input=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ PARALLEL_COMBOS = [
|
||||
BACKENDS = ["allgather_reducescatter"]
|
||||
|
||||
if has_mori():
|
||||
BACKENDS += ["mori"]
|
||||
BACKENDS += ["mori_high_throughput", "mori_low_latency"]
|
||||
|
||||
if has_flashinfer_nvlink_two_sided():
|
||||
BACKENDS += ["flashinfer_nvlink_two_sided"]
|
||||
@@ -118,7 +118,8 @@ QUANT_METHODS = [
|
||||
# fmt: off
|
||||
BACKEND_SUPPORTED_QUANTS: dict[str, set[str | None]] = {
|
||||
"allgather_reducescatter": {None, "fp8", "modelopt_fp8", "modelopt_fp4"}, # noqa: E501
|
||||
"mori": {None, "fp8", "modelopt_fp8"},
|
||||
"mori_high_throughput": {None, "fp8", "modelopt_fp8"},
|
||||
"mori_low_latency": {None, "fp8", "modelopt_fp8"},
|
||||
"flashinfer_nvlink_two_sided": {None, "fp8_blocked", "modelopt_fp4"}, # noqa: E501
|
||||
"flashinfer_nvlink_one_sided": {None, "modelopt_fp4"}, # noqa: E501
|
||||
"deepep_low_latency": {None, "fp8_blocked", "modelopt_fp4"}, # noqa: E501
|
||||
@@ -129,7 +130,8 @@ BACKEND_SUPPORTED_QUANTS: dict[str, set[str | None]] = {
|
||||
# Map from backend -> (DP/EP support, DP support, TP support, SP support)
|
||||
BACKEND_EP_DP_TP_SUPPORT: dict[str, tuple[bool, bool, bool, bool]] = {
|
||||
"allgather_reducescatter": (True, True, True, True),
|
||||
"mori": (True, False, False, True),
|
||||
"mori_high_throughput": (True, False, False, True),
|
||||
"mori_low_latency": (True, False, False, True),
|
||||
"flashinfer_nvlink_two_sided": (False, True, False, False),
|
||||
"flashinfer_nvlink_one_sided": (False, True, False, False),
|
||||
"deepep_low_latency": (True, False, False, True),
|
||||
|
||||
@@ -42,7 +42,8 @@ All2AllBackend = Literal[
|
||||
"pplx",
|
||||
"deepep_high_throughput",
|
||||
"deepep_low_latency",
|
||||
"mori",
|
||||
"mori_high_throughput",
|
||||
"mori_low_latency",
|
||||
"nixl_ep",
|
||||
"allgather_reducescatter",
|
||||
"flashinfer_all2allv", # temporary alias for flashinfer_nvlink_two_sided
|
||||
@@ -177,7 +178,8 @@ class ParallelConfig:
|
||||
- "allgather_reducescatter": All2all based on allgather and reducescatter
|
||||
- "deepep_high_throughput": Use deepep high-throughput kernels
|
||||
- "deepep_low_latency": Use deepep low-latency kernels
|
||||
- "mori": Use mori kernels
|
||||
- "mori_high_throughput": Use MoRI EP kernels (InterNodeV1 kernel when multi-node)
|
||||
- "mori_low_latency": Use MoRI EP kernels (InterNodeV1LL kernel when multi-node)
|
||||
- "nixl_ep": Use nixl-ep kernels
|
||||
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
|
||||
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
|
||||
@@ -621,7 +623,8 @@ class ParallelConfig:
|
||||
"allgather_reducescatter",
|
||||
"deepep_high_throughput",
|
||||
"deepep_low_latency",
|
||||
"mori",
|
||||
"mori_high_throughput",
|
||||
"mori_low_latency",
|
||||
"nixl_ep",
|
||||
)
|
||||
and self.enable_expert_parallel
|
||||
|
||||
@@ -764,14 +764,19 @@ class FlashInferNVLinkOneSidedManager(All2AllManagerBase):
|
||||
|
||||
|
||||
class MoriAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, all2all_backend: str):
|
||||
assert has_mori(), (
|
||||
"MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
|
||||
" to install MoRI kernels."
|
||||
) # noqa
|
||||
assert all2all_backend in (
|
||||
"mori_high_throughput",
|
||||
"mori_low_latency",
|
||||
), f"unsupported MoRI all2all backend: {all2all_backend!r}"
|
||||
import mori
|
||||
|
||||
super().__init__(cpu_group)
|
||||
self._all2all_backend = all2all_backend
|
||||
self.handle_cache = Cache()
|
||||
|
||||
torch._C._distributed_c10d._register_process_group("mori", cpu_group)
|
||||
@@ -805,8 +810,12 @@ class MoriAll2AllManager(All2AllManagerBase):
|
||||
warp_num_per_block = 16
|
||||
block_num = 80
|
||||
else:
|
||||
# multi node
|
||||
kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
|
||||
# Multi-node: kernel follows --all2all-backend (mirrors deepep_* split).
|
||||
# mori_low_latency → InterNodeV1LL; mori_high_throughput → InterNodeV1.
|
||||
if self._all2all_backend == "mori_low_latency":
|
||||
kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1LL
|
||||
else:
|
||||
kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
|
||||
if on_gfx942():
|
||||
warp_num_per_block = 16
|
||||
block_num = 32
|
||||
|
||||
@@ -137,10 +137,15 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "mori":
|
||||
elif self.all2all_backend in (
|
||||
"mori_high_throughput",
|
||||
"mori_low_latency",
|
||||
):
|
||||
from .all2all import MoriAll2AllManager
|
||||
|
||||
self.all2all_manager = MoriAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = MoriAll2AllManager(
|
||||
self.cpu_group, self.all2all_backend
|
||||
)
|
||||
elif self.all2all_backend == "nixl_ep":
|
||||
from .all2all import NixlEPAll2AllManager
|
||||
|
||||
|
||||
@@ -1091,7 +1091,10 @@ class FusedMoEParallelConfig:
|
||||
|
||||
@property
|
||||
def use_mori_kernels(self):
|
||||
return self.use_all2all_kernels and self.all2all_backend == "mori"
|
||||
return self.use_all2all_kernels and self.all2all_backend in (
|
||||
"mori_high_throughput",
|
||||
"mori_low_latency",
|
||||
)
|
||||
|
||||
@property
|
||||
def use_nixl_ep_kernels(self):
|
||||
|
||||
Reference in New Issue
Block a user