[ROCm][CI] Extend ROCm quick reduce coverage (#40990)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-05-26 08:57:13 -05:00
committed by GitHub
parent d565357a90
commit 445ded18c1
4 changed files with 869 additions and 38 deletions
+20 -4
View File
@@ -2703,19 +2703,35 @@ steps:
optional: true
working_dir: "/vllm-workspace/"
source_file_dependencies:
- csrc/custom_quickreduce.cu
- csrc/ops.h
- csrc/torch_bindings.cpp
- vllm/distributed/
- vllm/v1/distributed/
- vllm/model_executor/layers/
- vllm/entrypoints/llm.py
- vllm/config/parallel.py
- vllm/model_executor/layers/fused_moe/
- vllm/v1/engine/
- vllm/v1/executor/
- vllm/v1/worker/
- vllm/v1/distributed/
- vllm/v1/attention/backends/
- vllm/v1/attention/selector.py
- tests/distributed/test_context_parallel.py
- tests/v1/distributed/test_dbo.py
- examples/features/data_parallel/data_parallel_offline.py
- vllm/_aiter_ops.py
- vllm/_custom_ops.py
- vllm/platforms/rocm.py
- vllm/envs.py
- examples/offline_inference/data_parallel.py
- tests/distributed/test_context_parallel.py
- tests/distributed/test_rocm_quick_reduce.py
- tests/distributed/test_quick_all_reduce.py
- tests/v1/distributed/test_dbo.py
- tests/utils.py
commands:
- pytest -v -s tests/distributed/test_context_parallel.py
- pytest -v -s tests/v1/distributed/test_dbo.py
- pytest -v -s tests/distributed/test_rocm_quick_reduce.py
- pytest -v -s tests/distributed/test_quick_all_reduce.py
#-------------------------------------------------------- mi355 · entrypoints --------------------------------------------------------#
+73 -3
View File
@@ -25,16 +25,34 @@ from ..utils import (
ensure_model_parallel_initialized,
init_test_distributed_environment,
multi_process_parallel,
set_random_seed,
)
torch.manual_seed(42)
random.seed(44)
def on_gfx942() -> bool:
    """Return the ROCm platform's gfx942 probe result, or False off ROCm.

    The ROCm helper is imported lazily so the import only happens when the
    current platform reports ROCm.
    """
    if not current_platform.is_rocm():
        return False

    from vllm.platforms.rocm import on_gfx942 as rocm_on_gfx942

    return rocm_on_gfx942()
set_random_seed(42)
_test_size_rng = random.Random(44)
# Size over 8MB is sufficient for custom quick allreduce.
test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)]
test_sizes = [
_test_size_rng.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)
]
for i, v in enumerate(test_sizes):
test_sizes[i] -= v % 8
def _assert_quickreduce(fa, inp):
assert fa is not None
assert not fa.disabled
assert fa.should_quick_allreduce(inp)
@pytest.fixture
def envs_cache_disabled():
disable_envs_cache()
@@ -216,11 +234,14 @@ def graph_quickreduce(
):
with monkeypatch.context() as m:
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
m.delenv("HIP_VISIBLE_DEVICES", raising=False)
m.delenv("ROCR_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.accelerator.set_device_index(device)
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
ensure_model_parallel_initialized(tp_size, pp_size)
group = get_tp_group().device_group
fa = get_tp_group().device_communicator.qr_comm
# A small all_reduce for warmup.
# this is needed because device communicators might be created lazily
@@ -246,6 +267,8 @@ def graph_quickreduce(
device_idx = torch.accelerator.current_device_index()
inp1 = torch.randint(1, 23, (sz,), dtype=dtype, device=device_idx)
inp2 = torch.randint(-23, 1, (sz,), dtype=dtype, device=device_idx)
_assert_quickreduce(fa, inp1)
_assert_quickreduce(fa, inp2)
torch.accelerator.synchronize()
graph = torch.cuda.CUDAGraph()
@@ -270,6 +293,8 @@ def eager_quickreduce(
):
with monkeypatch.context() as m:
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
m.delenv("HIP_VISIBLE_DEVICES", raising=False)
m.delenv("ROCR_VISIBLE_DEVICES", raising=False)
device = torch.device(f"cuda:{rank}")
torch.accelerator.set_device_index(device)
@@ -281,12 +306,42 @@ def eager_quickreduce(
inp = torch.tensor(
[1.0 * ((i) % 23) for i in range(sz)], dtype=torch.float16, device=device
)
_assert_quickreduce(fa, inp)
out = fa.quick_all_reduce(inp)
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
inp = torch.tensor(
[1.0 * ((i) % 23) for i in range(sz)], dtype=torch.bfloat16, device=device
)
_assert_quickreduce(fa, inp)
out = fa.quick_all_reduce(inp)
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
@ray.remote(num_gpus=1, max_calls=1)
def bf16_cast_quickreduce(
    monkeypatch: pytest.MonkeyPatch,
    tp_size,
    pp_size,
    rank,
    distributed_init_port,
):
    """Ray worker: quick all-reduce of a bf16 tensor with the fp16 cast enabled.

    Sets ``VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16=1``, initialises the test
    distributed environment, checks that the quick-reduce communicator
    accepts a bf16 input and reports ``use_fp16_kernels``, then compares the
    reduced output with ``inp * tp_size``.

    Args:
        monkeypatch: Pytest monkeypatch used to scope the env-var changes.
        tp_size: Tensor-parallel world size.
        pp_size: Pipeline-parallel world size.
        rank: Rank of this worker; also used as the device index.
        distributed_init_port: Port passed to the distributed init helper.
    """
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        m.delenv("HIP_VISIBLE_DEVICES", raising=False)
        m.delenv("ROCR_VISIBLE_DEVICES", raising=False)
        m.setenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "1")
        device = torch.device(f"cuda:{rank}")
        torch.accelerator.set_device_index(device)
        init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)

        sz = 16 * 1024 * 1024
        fa = get_tp_group().device_communicator.qr_comm
        # Same values as ``[1.0 * (i % 23) for i in range(sz)]`` (integers
        # 0..22 are exact in bf16), but built on the device instead of via a
        # 16M-element Python list of float objects in every worker.
        inp = (torch.arange(sz, device=device) % 23).to(torch.bfloat16)
        _assert_quickreduce(fa, inp)
        assert fa.use_fp16_kernels
        out = fa.quick_all_reduce(inp)
        torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
@@ -308,12 +363,27 @@ def test_custom_quick_allreduce(
world_size = tp_size * pipeline_parallel_size
if world_size > torch.accelerator.device_count():
pytest.skip("Not enough GPUs to run the test.")
if test_target is graph_quickreduce and on_gfx942():
pytest.xfail(
"CUDA graph capture with quick reduce hits "
"hipErrorStreamCaptureInvalidated on gfx942"
)
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
@pytest.mark.skipif(
    not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
)
def test_custom_quick_allreduce_bf16_cast(monkeypatch: pytest.MonkeyPatch):
    """Run the bf16-cast quick-reduce worker on two ranks in FP mode."""
    required_gpus = 2
    if torch.accelerator.device_count() < required_gpus:
        pytest.skip("Not enough GPUs to run the test.")

    monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "FP")
    multi_process_parallel(monkeypatch, required_gpus, 1, bf16_cast_quickreduce)
def qr_variable_input(rank, world_size):
"""
When the tensor parallelism is set to 4 or 8, frequent changes
+750
View File
@@ -0,0 +1,750 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import importlib
import multiprocessing as mp
import os
import queue
import traceback
from functools import lru_cache
from types import SimpleNamespace
from typing import Literal
from unittest.mock import patch
import pytest
import torch
import torch.distributed as dist
from huggingface_hub import snapshot_download
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
# Skip every test in this module unless the current platform reports ROCm.
pytestmark = pytest.mark.skipif(
    not current_platform.is_rocm(),
    reason="ROCm-only quick-reduce tests",
)

# 1024 * 1024; used below as a unit for tensor element counts and size limits.
MB = 1024 * 1024
# Number of ranks (one spawned process each) used by the multi-GPU cases.
WORLD_SIZE = 2
# Quantization modes exercised by the two-GPU correctness test.
QUANT_LEVELS = ["FP", "INT8", "INT6", "INT4"]
def _log(message: str) -> None:
print(f"[rocm_quick_reduce] {message}", flush=True)
def _reload_envs():
    """Re-import ``vllm.envs`` and return the reloaded module object."""
    reloaded_module = importlib.reload(envs)
    return reloaded_module
def _make_quick_allreduce(
    *,
    disabled: bool = False,
    world_size: int = 2,
    quant_level: str = "FP",
    use_fp16_kernels: bool = False,
    qr_max_size: int = 64 * MB,
):
    """Build a ``QuickAllReduce`` without running its ``__init__``.

    The instance is created with ``__new__`` and only the attributes listed
    below are assigned, so no process group or device is touched.

    Args:
        disabled: Value stored on ``qar.disabled``.
        world_size: Value stored on ``qar.world_size``.
        quant_level: Name of the ``QuickReduceRegime`` member to store on
            ``qar.qr_quant_level``.
        use_fp16_kernels: Value stored on ``qar.use_fp16_kernels``.
        qr_max_size: Value stored on ``qar.qr_max_size``.

    Returns:
        The partially initialised ``QuickAllReduce`` instance.
    """
    from vllm.distributed.device_communicators.quick_all_reduce import (
        QuickAllReduce,
        QuickReduceRegime,
    )

    attributes = {
        "disabled": disabled,
        "world_size": world_size,
        "use_fp16_kernels": use_fp16_kernels,
        "qr_quant_level": QuickReduceRegime[quant_level],
        "qr_max_size": qr_max_size,
    }
    qar = QuickAllReduce.__new__(QuickAllReduce)
    for attr_name, attr_value in attributes.items():
        setattr(qar, attr_name, attr_value)
    return qar
def _quick_allreduce_worker(
    rank: int,
    port: int,
    quant_level: str,
    dtype_name: str,
    cast_bf16: bool,
):
    """Process entry point: run one quick all-reduce on ``cuda:{rank}``.

    Sets the quick-reduce env vars, joins a gloo process group of
    ``WORLD_SIZE`` ranks, builds a ``QuickAllReduce`` and checks that reducing
    a tensor of ones matches ``inp * WORLD_SIZE`` within tolerance. The
    communicator and process group are torn down in ``finally``.

    Args:
        rank: Rank of this process; also used as the device index.
        port: TCP port on 127.0.0.1 used for process-group rendezvous.
        quant_level: Value for ``VLLM_ROCM_QUICK_REDUCE_QUANTIZATION``.
        dtype_name: Name of the torch dtype for the input tensor.
        cast_bf16: Whether ``VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16`` is
            set to "1" (and ``use_fp16_kernels`` is asserted).
    """
    os.environ.update(
        {
            "VLLM_ROCM_QUICK_REDUCE_QUANTIZATION": quant_level,
            "VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16": "1" if cast_bf16 else "0",
        }
    )
    _log(
        f"worker start: rank={rank} quant={quant_level} "
        f"dtype={dtype_name} cast_bf16={cast_bf16}"
    )

    device = torch.device(f"cuda:{rank}")
    torch.accelerator.set_device_index(device)
    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://127.0.0.1:{port}",
        rank=rank,
        world_size=WORLD_SIZE,
    )

    qar = None
    try:
        from vllm.distributed.device_communicators.quick_all_reduce import (
            QuickAllReduce,
        )

        qar = QuickAllReduce(group=dist.GroupMember.WORLD, device=rank)
        assert not qar.disabled

        # float16 inputs use 8 * MB elements; every other dtype uses 4 * MB.
        num_elements = (8 if dtype_name == "float16" else 4) * MB
        inp = torch.ones(
            num_elements, dtype=getattr(torch, dtype_name), device=device
        )
        assert qar.should_quick_allreduce(inp)
        if cast_bf16:
            assert qar.use_fp16_kernels

        out = qar.quick_all_reduce(inp)
        assert torch.allclose(out, inp * WORLD_SIZE, atol=2.5, rtol=0.1)
        _log(
            f"worker complete: rank={rank} quant={quant_level} "
            f"dtype={dtype_name} num_elements={num_elements} "
            f"use_fp16_kernels={qar.use_fp16_kernels}"
        )
    finally:
        if qar is not None:
            qar.close()
        if dist.is_initialized():
            dist.destroy_process_group()
def _run_two_gpu_quick_allreduce_test(
    *,
    quant_level: str,
    dtype_name: str,
    cast_bf16: bool,
):
    """Spawn ``WORLD_SIZE`` workers and require all of them to exit cleanly.

    Each worker runs ``_quick_allreduce_worker`` in a "spawn" process. Every
    worker gets up to 60 seconds to finish. Workers that are still alive
    afterwards are terminated (then killed if needed), so a hung collective
    cannot leak processes into later tests.

    Args:
        quant_level: Quantization mode forwarded to each worker.
        dtype_name: Torch dtype name forwarded to each worker.
        cast_bf16: Whether workers enable the bf16-to-fp16 cast.

    Raises:
        AssertionError: If any worker times out or exits with a non-zero code.
    """
    _log(
        f"launch 2-GPU case: quant={quant_level} "
        f"dtype={dtype_name} cast_bf16={cast_bf16}"
    )
    ctx = mp.get_context("spawn")
    port = get_open_port()
    procs = []
    try:
        for rank in range(WORLD_SIZE):
            proc = ctx.Process(
                target=_quick_allreduce_worker,
                args=(rank, port, quant_level, dtype_name, cast_bf16),
            )
            proc.start()
            procs.append(proc)

        # Join every worker before asserting, so one failure does not leave
        # the remaining workers unjoined.
        for proc in procs:
            proc.join(timeout=60)

        for rank, proc in enumerate(procs):
            # ``exitcode`` is None while the process is still running, i.e.
            # the join above timed out.
            assert proc.exitcode is not None, (
                f"worker rank={rank} did not finish within 60 seconds"
            )
            assert proc.exitcode == 0, (
                f"worker rank={rank} exited with code {proc.exitcode}"
            )
    finally:
        for proc in procs:
            if proc.is_alive():
                proc.terminate()
                proc.join(timeout=5)
            if proc.is_alive():
                proc.kill()
                proc.join(timeout=5)
    _log(
        f"finished 2-GPU case: quant={quant_level} "
        f"dtype={dtype_name} cast_bf16={cast_bf16}"
    )
# Hugging Face repo id of the model used by the LLM end-to-end cases.
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
# Row count of the sample tensor used for the in-worker quick-reduce check.
E2E_PREFILL_TOKENS = 1024
# Passed to LLM as max_model_len.
E2E_MAX_MODEL_LEN = 1536
# Passed to LLM as gpu_memory_utilization.
E2E_GPU_MEMORY_UTILIZATION = 0.3
# Passed to LLM as kv_cache_memory_bytes (2 << 30 == 2 GiB).
E2E_KV_CACHE_MEMORY_BYTES = 2 << 30

# One filler sentence; repeated 48 times below to pad every prompt.
_BACKGROUND_LINE = (
    "Background filler: this archived operations memo repeats a routine status "
    "line so the distributed test uses a realistically long prefill."
)
_BACKGROUND_BLOCK = " ".join([_BACKGROUND_LINE] * 48)
def _build_prompt(*, fact_block: str, question: str) -> str:
    """Assemble a long prompt: instructions, filler, fact block, question.

    Args:
        fact_block: Text inserted after the "Fact block:" line.
        question: Question text inserted after "Question: ".

    Returns:
        The full prompt string.
    """
    sections = [
        "Read the archived operations memo below. Most of the memo is filler. "
        "Use only the fact block near the end when answering.\n",
        f"{_BACKGROUND_BLOCK}\n",
        "Fact block:\n",
        f"{fact_block}\n",
        f"Question: {question}\n",
        "Answer in one short sentence.",
    ]
    return "".join(sections)
# Prompts for the LLM end-to-end cases; each one ends with a short fact block
# and a question that can be answered from that block.
E2E_PROMPTS = [
    _build_prompt(
        fact_block=(
            "- Festival city: Oslo\n- Mascot animal: otter\n- Welcome drink: tea"
        ),
        question="Which city hosts the festival, and what animal is the mascot?",
    ),
    _build_prompt(
        fact_block=(
            "- Meeting day: Tuesday\n"
            "- Planned snack: apricot cake\n"
            "- Backup room: Cedar"
        ),
        question="What day is the meeting, and what snack is planned?",
    ),
]
# Recorded completions, index-aligned with E2E_PROMPTS. Generated text is
# compared against these for exact equality in _collect_soft_mismatches.
RECORDED_RESPONSE_TEXTS = (
    " The city hosting the festival is Oslo, and the mascot is an otter.",
    " The meeting is on Tuesday and the snack planned is apricot cake.",
)
# Lowercase words that must appear in the completion for the prompt at the
# same index (checked by _assert_required_words).
REQUIRED_WORDS = (("oslo", "otter"), ("tuesday", "apricot"))
def _log_prompt_summaries() -> None:
    """Log, for each prompt, the lines that start with "- " joined by "; "."""
    for i, prompt in enumerate(E2E_PROMPTS):
        facts = (
            line.removeprefix("- ")
            for line in prompt.splitlines()
            if line.startswith("- ")
        )
        fact_summary = "; ".join(facts)
        _log(f"prompt {i} facts: {fact_summary}")
@lru_cache(maxsize=1)
def _get_model_path() -> str:
    """Return a local snapshot path for ``MODEL_NAME`` (cached per process).

    A local-files-only lookup is tried first. If that raises, the snapshot is
    requested again without the local-only restriction.
    """
    try:
        path = snapshot_download(repo_id=MODEL_NAME, local_files_only=True)
        _log(f"using cached model snapshot: {path}")
    except Exception:
        # Best-effort fallback: any failure of the local lookup leads to a
        # regular snapshot_download call.
        path = snapshot_download(repo_id=MODEL_NAME)
        _log(f"downloaded model snapshot: {path}")
    return path
def _get_hidden_size(model_config) -> int:
hidden_size = getattr(model_config, "hidden_size", None)
if hidden_size is None and hasattr(model_config, "text_config"):
hidden_size = getattr(model_config.text_config, "hidden_size", None)
assert isinstance(hidden_size, int)
return hidden_size
def _check_tp_allreduce_uses_quick_reduce(
    self,
    num_tokens: int,
    dtype_name: str = "float16",
) -> dict[str, int | bool]:
    """Worker-side check that the TP all-reduce is eligible for quick reduce.

    Intended to be invoked on each worker (``self`` is the worker object).
    Builds a ``(num_tokens, hidden_size)`` tensor filled with ``rank + 1``,
    asserts the quick-reduce communicator accepts it, and compares the
    tensor-parallel all-reduce result with ``dist.all_reduce`` on a clone.

    Args:
        num_tokens: Number of rows in the sample tensor.
        dtype_name: Name of the torch dtype for the sample tensor.

    Returns:
        Dict with ``rank``, ``hidden_size``, ``num_tokens`` and
        ``use_fp16_kernels`` for this worker.
    """
    from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
    from vllm.distributed.parallel_state import get_tp_group

    assert self.device is not None
    qr_comm = get_tp_group().device_communicator.qr_comm
    assert qr_comm is not None
    assert not qr_comm.disabled

    hidden_size = _get_hidden_size(self.model_runner.model.config)
    sample = torch.full(
        (num_tokens, hidden_size),
        fill_value=float(self.rank + 1),
        dtype=getattr(torch, dtype_name),
        device=self.device,
    )
    assert qr_comm.should_quick_allreduce(sample)

    # Reference result: reduce a copy through the plain device group.
    expected = sample.clone()
    reduced = tensor_model_parallel_all_reduce(sample)
    dist.all_reduce(expected, group=get_tp_group().device_group)
    torch.testing.assert_close(reduced, expected, atol=2.5, rtol=0.1)

    stats = {
        "rank": self.rank,
        "hidden_size": hidden_size,
        "num_tokens": num_tokens,
        "use_fp16_kernels": qr_comm.use_fp16_kernels,
    }
    _log(
        "worker quick-reduce check: "
        f"rank={self.rank} hidden_size={hidden_size} "
        f"num_tokens={num_tokens} use_fp16_kernels={qr_comm.use_fp16_kernels}"
    )
    return stats
def _check_quick_reduce_disabled(self) -> int:
    """Worker-side check that a quick-reduce communicator exists but is off.

    Returns:
        The rank of the worker (``self.rank``).
    """
    from vllm.distributed.parallel_state import get_tp_group

    communicator = get_tp_group().device_communicator.qr_comm
    assert communicator is not None
    assert communicator.disabled
    _log(f"worker confirmed quick reduce is disabled: rank={self.rank}")
    return self.rank
def _collect_generations(outputs) -> list[tuple[tuple[int, ...], str]]:
return [
(tuple(output.outputs[0].token_ids), output.outputs[0].text)
for output in outputs
]
def _shutdown_llm(llm: LLM | None) -> None:
    """Best-effort engine shutdown followed by distributed/memory cleanup.

    Args:
        llm: The LLM to shut down, or ``None`` if construction never
            succeeded. Cleanup still runs in that case.
    """
    if llm is not None:
        # Shutdown errors are deliberately ignored so cleanup always runs.
        with contextlib.suppress(Exception):
            llm.llm_engine.engine_core.shutdown()
        del llm
    cleanup_dist_env_and_memory()
def _log_generations(
    label: str,
    generations: list[tuple[tuple[int, ...], str]],
) -> None:
    """Log the token ids and text of every generation under ``label``."""
    for i, generation in enumerate(generations):
        token_ids, text = generation
        _log(f"{label} prompt {i} token ids: {list(token_ids)}")
        _log(f"{label} prompt {i} text: {text!r}")
def _assert_required_words(
    label: str,
    generations: list[tuple[tuple[int, ...], str]],
) -> None:
    """Assert each generation contains its ``REQUIRED_WORDS`` entry.

    The comparison is case-insensitive on the generated text.

    Raises:
        AssertionError: Naming the missing words and the observed text.
    """
    for i, (_, text) in enumerate(generations):
        haystack = text.lower()
        missing = []
        for word in REQUIRED_WORDS[i]:
            if word not in haystack:
                missing.append(word)
        assert not missing, (
            f"{label} prompt {i} is missing required words {missing}. "
            f"Observed text: {text!r}"
        )
def _collect_soft_mismatches(
    baseline_generations: list[tuple[tuple[int, ...], str]],
    quick_reduce_generations: list[tuple[tuple[int, ...], str]],
) -> list[str]:
    """Describe every exact-text difference between runs and recordings.

    Reports, in this order: baseline texts that differ from
    ``RECORDED_RESPONSE_TEXTS``, quick-reduce texts that differ from it, and
    prompts where baseline and quick-reduce texts differ from each other.

    Returns:
        A list of human-readable mismatch descriptions (empty if none).
    """
    problems = []

    labelled_runs = (
        ("baseline", baseline_generations),
        ("quick-reduce", quick_reduce_generations),
    )
    for label, generations in labelled_runs:
        for i, (_, text) in enumerate(generations):
            expected = RECORDED_RESPONSE_TEXTS[i]
            if text == expected:
                continue
            problems.append(
                f"{label} prompt {i} drifted from the recorded response.\n"
                f"expected={expected!r}\nactual={text!r}"
            )

    paired = zip(baseline_generations, quick_reduce_generations)
    for i, pair in enumerate(paired):
        (_, baseline_text), (_, quick_reduce_text) = pair
        if baseline_text == quick_reduce_text:
            continue
        problems.append(
            f"baseline and quick-reduce responses differ for prompt {i}.\n"
            f"baseline={baseline_text!r}\nquick_reduce={quick_reduce_text!r}"
        )
    return problems
def _run_generation(
    *,
    backend: Literal["mp", "ray"],
    quant_mode: str,
    expect_quick_reduce: bool,
) -> list[tuple[tuple[int, ...], str]]:
    """Run the e2e prompts through a TP=2 LLM and return the generations.

    The quick-reduce quantization env var is set for the duration of the
    call. Before generating, each worker is asked to confirm that quick
    reduce is either usable or disabled, depending on ``expect_quick_reduce``.
    The LLM is always shut down in ``finally``.

    Args:
        backend: Distributed executor backend passed to ``LLM``.
        quant_mode: Value for ``VLLM_ROCM_QUICK_REDUCE_QUANTIZATION``.
        expect_quick_reduce: Which worker-side check to run.

    Returns:
        ``(token_ids, text)`` per prompt, as built by ``_collect_generations``.
    """
    llm = None
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        m.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
        model_path = _get_model_path()
        _log(
            f"starting generation: backend={backend} quant={quant_mode} "
            f"gpu_memory_utilization={E2E_GPU_MEMORY_UTILIZATION} "
            f"kv_cache_bytes={E2E_KV_CACHE_MEMORY_BYTES} model={model_path}"
        )
        try:
            llm = LLM(
                model=model_path,
                tokenizer=model_path,
                tensor_parallel_size=2,
                distributed_executor_backend=backend,
                dtype="half",
                enforce_eager=True,
                max_model_len=E2E_MAX_MODEL_LEN,
                max_num_seqs=len(E2E_PROMPTS),
                gpu_memory_utilization=E2E_GPU_MEMORY_UTILIZATION,
                kv_cache_memory_bytes=E2E_KV_CACHE_MEMORY_BYTES,
                seed=0,
            )

            if expect_quick_reduce:
                worker_stats = llm.collective_rpc(
                    _check_tp_allreduce_uses_quick_reduce,
                    args=(E2E_PREFILL_TOKENS,),
                )
                assert [stat["rank"] for stat in worker_stats] == [0, 1]
                worker_summary = "; ".join(
                    "rank={rank} hidden_size={hidden_size} num_tokens={num_tokens} "
                    "use_fp16_kernels={use_fp16_kernels}".format(**stat)
                    for stat in worker_stats
                )
                _log(f"{backend} quick-reduce worker checks: {worker_summary}")
            else:
                disabled_ranks = llm.collective_rpc(_check_quick_reduce_disabled)
                assert disabled_ranks == [0, 1]

            sampling_params = SamplingParams(
                temperature=0.0,
                max_tokens=20,
                stop=["\nAnswer:", " Answer:"],
            )
            outputs = llm.generate(E2E_PROMPTS, sampling_params, use_tqdm=False)
            generations = _collect_generations(outputs)
            # Every prompt must produce some non-whitespace text.
            assert all(text.strip() for _, text in generations)
            _log_generations(f"{backend} {quant_mode}", generations)
            return generations
        finally:
            _shutdown_llm(llm)
def _run_quick_reduce_llm_e2e_in_subprocess(
    *,
    backend: Literal["mp", "ray"],
) -> str | None:
    """Generate with quick reduce off ("NONE") and on ("FP") and compare.

    Required words are hard assertions. Exact-text differences against the
    recorded responses (or between the two runs) are returned as a string.

    Returns:
        ``None`` when everything matches exactly, otherwise the joined
        mismatch descriptions.
    """
    _log(f"running LLM e2e: backend={backend}")
    _log_prompt_summaries()

    baseline_outputs = _run_generation(
        backend=backend, quant_mode="NONE", expect_quick_reduce=False
    )
    quick_reduce_outputs = _run_generation(
        backend=backend, quant_mode="FP", expect_quick_reduce=True
    )

    _assert_required_words("baseline", baseline_outputs)
    _assert_required_words("quick-reduce", quick_reduce_outputs)

    mismatches = _collect_soft_mismatches(baseline_outputs, quick_reduce_outputs)
    if not mismatches:
        _log(f"LLM e2e backend={backend} matched the recorded responses exactly")
        return None

    details = "\n\n".join(mismatches)
    _log(f"soft response mismatch:\n{details}")
    return details
def _quick_reduce_llm_e2e_worker(
    result_queue: mp.Queue,
    backend: Literal["mp", "ray"],
) -> None:
    """Subprocess entry point: run the e2e flow and report via the queue.

    Puts exactly one dict on ``result_queue`` with ``status`` "ok", "xfail"
    (plus ``reason``) or "error" (plus the formatted traceback). On error the
    exception is re-raised after reporting.
    """
    try:
        xfail_reason = _run_quick_reduce_llm_e2e_in_subprocess(backend=backend)
    except Exception:
        result_queue.put({"status": "error", "reason": traceback.format_exc()})
        raise

    if xfail_reason is None:
        outcome = {"status": "ok"}
    else:
        outcome = {"status": "xfail", "reason": xfail_reason}
    result_queue.put(outcome)
def run_quick_reduce_llm_e2e(
    *,
    backend: Literal["mp", "ray"],
) -> None:
    """Run the LLM e2e flow in a spawned subprocess and interpret its result.

    The child reports a status dict through a queue. The queue is drained
    *before* the child is joined: the multiprocessing docs warn that a
    process which has put items on a queue may not exit until they are
    consumed, so joining first can stall. A child that is still alive once
    the time budget is spent is terminated (then killed) so it cannot leak
    into later tests.

    Args:
        backend: Distributed executor backend forwarded to the subprocess.

    Raises:
        AssertionError: If the child times out, reports an error, produces
            no result, or exits with a non-zero code.
    """
    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    proc = ctx.Process(
        target=_quick_reduce_llm_e2e_worker,
        args=(result_queue, backend),
    )
    proc.start()

    result = None
    timed_out = False
    try:
        # Poll once per second for up to ~600 seconds, and stop early if the
        # child dies without reporting (e.g. a hard crash).
        for _ in range(600):
            try:
                result = result_queue.get(timeout=1)
            except queue.Empty:
                if proc.is_alive():
                    continue
                # The child is gone; take one last look in case its result
                # was flushed right as it exited.
                try:
                    result = result_queue.get(timeout=5)
                except queue.Empty:
                    result = None
            break
        else:
            timed_out = True

        if not timed_out:
            # Result (or early death) observed: give the child time to exit.
            proc.join(timeout=60)
    finally:
        if proc.is_alive():
            proc.terminate()
            proc.join(timeout=10)
        if proc.is_alive():
            proc.kill()
            proc.join(timeout=10)

    if result is None:
        if timed_out:
            raise AssertionError(
                f"quick-reduce llm e2e subprocess timed out for backend={backend} "
                "after 600 seconds and produced no result"
            )
        if proc.exitcode != 0:
            raise AssertionError(
                f"quick-reduce llm e2e subprocess failed for backend={backend} "
                f"with exit code {proc.exitcode} and produced no result"
            )
        raise AssertionError(
            f"quick-reduce llm e2e subprocess produced no result for backend={backend}"
        )

    if result["status"] == "xfail":
        pytest.xfail(result["reason"])
    if result["status"] == "error":
        raise AssertionError(
            f"quick-reduce llm e2e subprocess failed for backend={backend}:\n"
            f"{result['reason']}"
        )
    assert proc.exitcode == 0, (
        f"quick-reduce llm e2e subprocess failed for backend={backend} "
        f"with exit code {proc.exitcode}"
    )
def test_quick_reduce_regime_values():
    """Pin the integer value of every ``QuickReduceRegime`` member."""
    from vllm.distributed.device_communicators.quick_all_reduce import QuickReduceRegime

    expected_values = {"FP": 0, "INT8": 1, "INT6": 2, "INT4": 3, "NONE": 4}
    for member_name, expected in expected_values.items():
        assert QuickReduceRegime[member_name].value == expected
def test_quick_reduce_regime_names():
    """Pin the exact set of ``QuickReduceRegime`` member names."""
    from vllm.distributed.device_communicators.quick_all_reduce import QuickReduceRegime

    expected_names = {"FP", "INT8", "INT6", "INT4", "NONE"}
    actual_names = set(QuickReduceRegime.__members__)
    assert actual_names == expected_names
@pytest.mark.parametrize("quant_level", QUANT_LEVELS + ["NONE"])
def test_quick_reduce_quantization_env_var(monkeypatch, quant_level):
    """The reloaded envs module reflects the quantization env var as set."""
    monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_level)
    observed = _reload_envs().VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
    assert quant_level == observed
def test_quick_reduce_quantization_default(monkeypatch):
    """With the env var unset, the quantization setting reads "NONE"."""
    monkeypatch.delenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", raising=False)
    observed = _reload_envs().VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
    assert observed == "NONE"
@pytest.mark.parametrize("cast_bf16", [True, False])
def test_quick_reduce_cast_bf16_to_fp16_env_var(monkeypatch, cast_bf16):
    """ "1"/"0" in the cast env var is read back as the matching bool."""
    raw_value = "1" if cast_bf16 else "0"
    monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", raw_value)
    observed = _reload_envs().VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16
    assert observed is cast_bf16
def test_quick_reduce_cast_bf16_to_fp16_default(monkeypatch):
    """With the env var unset, the bf16-to-fp16 cast setting reads True."""
    monkeypatch.delenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", raising=False)
    observed = _reload_envs().VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16
    assert observed is True
@pytest.mark.parametrize("max_mb", [128, 512, 2048, None])
def test_quick_reduce_max_size_env_var(monkeypatch, max_mb):
    """The max-size env var is read back as an int, or None when unset."""
    env_name = "VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB"
    if max_mb is not None:
        monkeypatch.setenv(env_name, str(max_mb))
    else:
        monkeypatch.delenv(env_name, raising=False)
    observed = _reload_envs().VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
    assert max_mb == observed
def test_quick_reduce_max_size_default(monkeypatch):
    """With the env var unset, the max-size setting reads None."""
    monkeypatch.delenv("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", raising=False)
    observed = _reload_envs().VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
    assert observed is None
@pytest.mark.parametrize(
    ("gcn_arch_name", "expected"),
    [
        ("gfx942", True),
        ("gfx950", True),
        ("gfx90a", False),
        ("", False),
    ],
)
def test_quick_allreduce_rocm_arch_available(gcn_arch_name, expected):
    """``_rocm_arch_available`` follows the reported ``gcnArchName``.

    The platform check and the device-properties query are both patched, so
    no device is touched.
    """
    from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce

    is_rocm_target = (
        "vllm.distributed.device_communicators.quick_all_reduce.current_platform."
        "is_rocm"
    )
    fake_properties = SimpleNamespace(gcnArchName=gcn_arch_name)

    qar = QuickAllReduce.__new__(QuickAllReduce)
    qar.disabled = True
    with (
        patch(is_rocm_target, return_value=True),
        patch("torch.cuda.get_device_properties", return_value=fake_properties),
    ):
        assert qar._rocm_arch_available() is expected
def test_quick_allreduce_rocm_arch_available_handles_probe_failure():
    """A ``RuntimeError`` from the device-properties query yields False."""
    from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce

    is_rocm_target = (
        "vllm.distributed.device_communicators.quick_all_reduce.current_platform."
        "is_rocm"
    )

    qar = QuickAllReduce.__new__(QuickAllReduce)
    qar.disabled = True
    with (
        patch(is_rocm_target, return_value=True),
        patch("torch.cuda.get_device_properties", side_effect=RuntimeError),
    ):
        assert qar._rocm_arch_available() is False
def test_quick_allreduce_rejects_disabled():
    """A disabled communicator never accepts an input."""
    disabled_qar = _make_quick_allreduce(disabled=True)
    tensor = torch.zeros(1024, dtype=torch.float16)
    accepted = disabled_qar.should_quick_allreduce(tensor)
    assert accepted is False
def test_quick_allreduce_rejects_unsupported_dtype():
    """float32 inputs are rejected."""
    communicator = _make_quick_allreduce()
    tensor = torch.zeros(1024 * 1024, dtype=torch.float32)
    accepted = communicator.should_quick_allreduce(tensor)
    assert accepted is False
def test_quick_allreduce_rejects_non_aligned_input():
    """A 5-element float16 input is rejected."""
    communicator = _make_quick_allreduce()
    tensor = torch.zeros(5, dtype=torch.float16)
    accepted = communicator.should_quick_allreduce(tensor)
    assert accepted is False
def test_quick_allreduce_rejects_non_contiguous_input():
    """A strided column slice of a 2-D tensor is rejected."""
    communicator = _make_quick_allreduce()
    dense = torch.zeros((1024, 1024), dtype=torch.float16)
    strided_view = dense[:, ::2]
    accepted = communicator.should_quick_allreduce(strided_view)
    assert accepted is False
def test_quick_allreduce_rejects_input_smaller_than_threshold():
    """A float16 input 8 elements short of ``MB // 2`` is rejected."""
    communicator = _make_quick_allreduce()
    num_elements = (MB // 2) - 8
    tensor = torch.zeros(num_elements, dtype=torch.float16)
    accepted = communicator.should_quick_allreduce(tensor)
    assert accepted is False
def test_quick_allreduce_accepts_input_at_threshold():
    """A float16 input of exactly ``MB // 2`` elements is accepted."""
    communicator = _make_quick_allreduce()
    num_elements = MB // 2
    tensor = torch.zeros(num_elements, dtype=torch.float16)
    accepted = communicator.should_quick_allreduce(tensor)
    assert accepted is True
def test_quick_allreduce_rejects_input_larger_than_max_size():
    """``MB`` float16 elements are rejected when ``qr_max_size`` is ``1 * MB``."""
    size_limited = _make_quick_allreduce(qr_max_size=1 * MB)
    tensor = torch.zeros(MB, dtype=torch.float16)
    accepted = size_limited.should_quick_allreduce(tensor)
    assert accepted is False
def test_quick_allreduce_bf16_uses_fp16_threshold_when_cast_enabled():
    """A ``MB // 2`` bf16 input is accepted only when fp16 kernels are used."""
    tensor = torch.zeros(MB // 2, dtype=torch.bfloat16)

    without_cast = _make_quick_allreduce(use_fp16_kernels=False)
    with_cast = _make_quick_allreduce(use_fp16_kernels=True)

    rejected_without_cast = without_cast.should_quick_allreduce(tensor)
    assert rejected_without_cast is False
    accepted_with_cast = with_cast.should_quick_allreduce(tensor)
    assert accepted_with_cast is True
def test_quick_allreduce_supported_world_sizes():
    """Pin the supported world sizes to ``[2, 4, 8]``."""
    from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce

    supported = QuickAllReduce._SUPPORTED_WORLD_SIZES
    assert supported == [2, 4, 8]
def test_quick_allreduce_supported_dtypes():
    """Pin the supported dtypes to float16 and bfloat16, in that order."""
    from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce

    expected_dtypes = [torch.float16, torch.bfloat16]
    assert expected_dtypes == QuickAllReduce._SUPPORTED_DTYPES
def test_quick_allreduce_min_size_table():
    """Each (dtype, world size) entry holds four positive minimum sizes."""
    from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce

    table = QuickAllReduce._QR_MIN_SIZE
    for dtype in (torch.float16, torch.bfloat16):
        for world_size in QuickAllReduce._SUPPORTED_WORLD_SIZES:
            thresholds = table[(dtype, world_size)]
            assert len(thresholds) == 4
            assert all(threshold > 0 for threshold in thresholds)
def test_qr_max_size():
    """``ops.qr_max_size()`` returns a positive int."""
    from vllm import _custom_ops as ops

    reported = ops.qr_max_size()
    assert isinstance(reported, int)
    assert reported > 0
@pytest.mark.skipif(
    current_platform.device_count() < WORLD_SIZE,
    reason="requires 2 ROCm GPUs",
)
@pytest.mark.parametrize("quant_level", QUANT_LEVELS)
def test_quick_allreduce_two_gpu_correctness(quant_level):
    """Two-process float16 quick all-reduce for each quantization level."""
    _log(f"two-GPU correctness case: quant={quant_level}")
    case = {
        "quant_level": quant_level,
        "dtype_name": "float16",
        "cast_bf16": False,
    }
    _run_two_gpu_quick_allreduce_test(**case)
@pytest.mark.skipif(
    current_platform.device_count() < WORLD_SIZE,
    reason="requires 2 ROCm GPUs",
)
def test_quick_allreduce_bf16_cast_mode():
    """Two-process bfloat16 quick all-reduce in FP mode with the cast on."""
    _log("BF16 cast case")
    case = {
        "quant_level": "FP",
        "dtype_name": "bfloat16",
        "cast_bf16": True,
    }
    _run_two_gpu_quick_allreduce_test(**case)
@pytest.mark.skipif(
    current_platform.device_count() < WORLD_SIZE,
    reason="requires 2 ROCm GPUs",
)
def test_quick_allreduce_llm_e2e():
    """LLM end-to-end quick-reduce comparison using the "mp" backend."""
    backend = "mp"
    _log("LLM e2e case: backend=mp")
    run_quick_reduce_llm_e2e(backend=backend)
@pytest.mark.skipif(
    current_platform.device_count() < WORLD_SIZE,
    reason="requires 2 ROCm GPUs",
)
def test_quick_allreduce_llm_e2e_ray():
    """LLM end-to-end quick-reduce comparison using the "ray" backend."""
    backend = "ray"
    _log("LLM e2e case: backend=ray")
    run_quick_reduce_llm_e2e(backend=backend)
+26 -31
View File
@@ -1363,43 +1363,38 @@ def multi_process_parallel(
) -> None:
import ray
# Using ray helps debugging the error when it failed
# as compared to multiprocessing.
# NOTE: We need to set working_dir for distributed tests,
# otherwise we may get import errors on ray workers
# NOTE: Force ray not to use gitignore file as excluding, otherwise
# it will not move .so files to working dir.
# So we have to manually add some of large directories
os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1"
# Using ray helps debugging the error when it failed as compared to
# multiprocessing. For local Ray workers, putting the repo root on
# PYTHONPATH is enough and avoids uploading the full source tree, which
# exceeds Ray's working_dir package size limit on CI.
env_vars = {
"PYTHONPATH": os.pathsep.join(
filter(None, [str(VLLM_PATH), os.environ.get("PYTHONPATH")])
),
**{env_var: "1" for env_var in current_platform.ray_noset_device_env_vars},
}
ray.init(
runtime_env={
"working_dir": VLLM_PATH,
"excludes": [
"build",
".git",
"cmake-build-*",
"shellcheck",
"dist",
"ep_kernels_workspace",
],
"env_vars": env_vars,
}
)
distributed_init_port = get_open_port()
refs = []
for rank in range(tp_size * pp_size):
refs.append(
test_target.remote(
monkeypatch,
tp_size,
pp_size,
rank,
distributed_init_port,
),
)
ray.get(refs)
ray.shutdown()
try:
refs = []
for rank in range(tp_size * pp_size):
refs.append(
test_target.remote(
monkeypatch,
tp_size,
pp_size,
rank,
distributed_init_port,
),
)
ray.get(refs)
finally:
ray.shutdown()
@contextmanager