mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[ROCm][CI] Extend ROCm quick reduce coverage (#40990)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -2703,19 +2703,35 @@ steps:
|
||||
optional: true
|
||||
working_dir: "/vllm-workspace/"
|
||||
source_file_dependencies:
|
||||
- csrc/custom_quickreduce.cu
|
||||
- csrc/ops.h
|
||||
- csrc/torch_bindings.cpp
|
||||
- vllm/distributed/
|
||||
- vllm/v1/distributed/
|
||||
- vllm/model_executor/layers/
|
||||
- vllm/entrypoints/llm.py
|
||||
- vllm/config/parallel.py
|
||||
- vllm/model_executor/layers/fused_moe/
|
||||
- vllm/v1/engine/
|
||||
- vllm/v1/executor/
|
||||
- vllm/v1/worker/
|
||||
- vllm/v1/distributed/
|
||||
- vllm/v1/attention/backends/
|
||||
- vllm/v1/attention/selector.py
|
||||
- tests/distributed/test_context_parallel.py
|
||||
- tests/v1/distributed/test_dbo.py
|
||||
- examples/features/data_parallel/data_parallel_offline.py
|
||||
- vllm/_aiter_ops.py
|
||||
- vllm/_custom_ops.py
|
||||
- vllm/platforms/rocm.py
|
||||
- vllm/envs.py
|
||||
- examples/offline_inference/data_parallel.py
|
||||
- tests/distributed/test_context_parallel.py
|
||||
- tests/distributed/test_rocm_quick_reduce.py
|
||||
- tests/distributed/test_quick_all_reduce.py
|
||||
- tests/v1/distributed/test_dbo.py
|
||||
- tests/utils.py
|
||||
commands:
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||
- pytest -v -s tests/distributed/test_rocm_quick_reduce.py
|
||||
- pytest -v -s tests/distributed/test_quick_all_reduce.py
|
||||
|
||||
#-------------------------------------------------------- mi355 · entrypoints --------------------------------------------------------#
|
||||
|
||||
|
||||
@@ -25,16 +25,34 @@ from ..utils import (
|
||||
ensure_model_parallel_initialized,
|
||||
init_test_distributed_environment,
|
||||
multi_process_parallel,
|
||||
set_random_seed,
|
||||
)
|
||||
|
||||
torch.manual_seed(42)
|
||||
random.seed(44)
|
||||
|
||||
def on_gfx942() -> bool:
|
||||
if current_platform.is_rocm():
|
||||
from vllm.platforms.rocm import on_gfx942 as rocm_on_gfx942
|
||||
|
||||
return rocm_on_gfx942()
|
||||
return False
|
||||
|
||||
|
||||
set_random_seed(42)
|
||||
_test_size_rng = random.Random(44)
|
||||
# Size over 8MB is sufficient for custom quick allreduce.
|
||||
test_sizes = [random.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)]
|
||||
test_sizes = [
|
||||
_test_size_rng.randint(8 * 1024 * 1024, 10 * 1024 * 1024) for _ in range(8)
|
||||
]
|
||||
for i, v in enumerate(test_sizes):
|
||||
test_sizes[i] -= v % 8
|
||||
|
||||
|
||||
def _assert_quickreduce(fa, inp):
|
||||
assert fa is not None
|
||||
assert not fa.disabled
|
||||
assert fa.should_quick_allreduce(inp)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def envs_cache_disabled():
|
||||
disable_envs_cache()
|
||||
@@ -216,11 +234,14 @@ def graph_quickreduce(
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
m.delenv("HIP_VISIBLE_DEVICES", raising=False)
|
||||
m.delenv("ROCR_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.accelerator.set_device_index(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
group = get_tp_group().device_group
|
||||
fa = get_tp_group().device_communicator.qr_comm
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
# this is needed because device communicators might be created lazily
|
||||
@@ -246,6 +267,8 @@ def graph_quickreduce(
|
||||
device_idx = torch.accelerator.current_device_index()
|
||||
inp1 = torch.randint(1, 23, (sz,), dtype=dtype, device=device_idx)
|
||||
inp2 = torch.randint(-23, 1, (sz,), dtype=dtype, device=device_idx)
|
||||
_assert_quickreduce(fa, inp1)
|
||||
_assert_quickreduce(fa, inp2)
|
||||
|
||||
torch.accelerator.synchronize()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
@@ -270,6 +293,8 @@ def eager_quickreduce(
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
m.delenv("HIP_VISIBLE_DEVICES", raising=False)
|
||||
m.delenv("ROCR_VISIBLE_DEVICES", raising=False)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.accelerator.set_device_index(device)
|
||||
|
||||
@@ -281,12 +306,42 @@ def eager_quickreduce(
|
||||
inp = torch.tensor(
|
||||
[1.0 * ((i) % 23) for i in range(sz)], dtype=torch.float16, device=device
|
||||
)
|
||||
_assert_quickreduce(fa, inp)
|
||||
out = fa.quick_all_reduce(inp)
|
||||
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
||||
|
||||
inp = torch.tensor(
|
||||
[1.0 * ((i) % 23) for i in range(sz)], dtype=torch.bfloat16, device=device
|
||||
)
|
||||
_assert_quickreduce(fa, inp)
|
||||
out = fa.quick_all_reduce(inp)
|
||||
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def bf16_cast_quickreduce(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tp_size,
|
||||
pp_size,
|
||||
rank,
|
||||
distributed_init_port,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
m.delenv("HIP_VISIBLE_DEVICES", raising=False)
|
||||
m.delenv("ROCR_VISIBLE_DEVICES", raising=False)
|
||||
m.setenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "1")
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.accelerator.set_device_index(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank, distributed_init_port)
|
||||
|
||||
sz = 16 * 1024 * 1024
|
||||
fa = get_tp_group().device_communicator.qr_comm
|
||||
inp = torch.tensor(
|
||||
[1.0 * (i % 23) for i in range(sz)], dtype=torch.bfloat16, device=device
|
||||
)
|
||||
_assert_quickreduce(fa, inp)
|
||||
assert fa.use_fp16_kernels
|
||||
out = fa.quick_all_reduce(inp)
|
||||
torch.testing.assert_close(out, inp * tp_size, atol=2.5, rtol=0.1)
|
||||
|
||||
@@ -308,12 +363,27 @@ def test_custom_quick_allreduce(
|
||||
world_size = tp_size * pipeline_parallel_size
|
||||
if world_size > torch.accelerator.device_count():
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
if test_target is graph_quickreduce and on_gfx942():
|
||||
pytest.xfail(
|
||||
"CUDA graph capture with quick reduce hits "
|
||||
"hipErrorStreamCaptureInvalidated on gfx942"
|
||||
)
|
||||
|
||||
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
|
||||
|
||||
multi_process_parallel(monkeypatch, tp_size, pipeline_parallel_size, test_target)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_rocm(), reason="only test quick allreduce for rocm"
|
||||
)
|
||||
def test_custom_quick_allreduce_bf16_cast(monkeypatch: pytest.MonkeyPatch):
|
||||
if torch.accelerator.device_count() < 2:
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", "FP")
|
||||
multi_process_parallel(monkeypatch, 2, 1, bf16_cast_quickreduce)
|
||||
|
||||
|
||||
def qr_variable_input(rank, world_size):
|
||||
"""
|
||||
When the tensor parallelism is set to 4 or 8, frequent changes
|
||||
|
||||
@@ -0,0 +1,750 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import importlib
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import queue
|
||||
import traceback
|
||||
from functools import lru_cache
|
||||
from types import SimpleNamespace
|
||||
from typing import Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not current_platform.is_rocm(),
|
||||
reason="ROCm-only quick-reduce tests",
|
||||
)
|
||||
|
||||
MB = 1024 * 1024
|
||||
WORLD_SIZE = 2
|
||||
QUANT_LEVELS = ["FP", "INT8", "INT6", "INT4"]
|
||||
|
||||
|
||||
def _log(message: str) -> None:
|
||||
print(f"[rocm_quick_reduce] {message}", flush=True)
|
||||
|
||||
|
||||
def _reload_envs():
|
||||
return importlib.reload(envs)
|
||||
|
||||
|
||||
def _make_quick_allreduce(
|
||||
*,
|
||||
disabled: bool = False,
|
||||
world_size: int = 2,
|
||||
quant_level: str = "FP",
|
||||
use_fp16_kernels: bool = False,
|
||||
qr_max_size: int = 64 * MB,
|
||||
):
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import (
|
||||
QuickAllReduce,
|
||||
QuickReduceRegime,
|
||||
)
|
||||
|
||||
qar = QuickAllReduce.__new__(QuickAllReduce)
|
||||
qar.disabled = disabled
|
||||
qar.world_size = world_size
|
||||
qar.use_fp16_kernels = use_fp16_kernels
|
||||
qar.qr_quant_level = QuickReduceRegime[quant_level]
|
||||
qar.qr_max_size = qr_max_size
|
||||
return qar
|
||||
|
||||
|
||||
def _quick_allreduce_worker(
|
||||
rank: int,
|
||||
port: int,
|
||||
quant_level: str,
|
||||
dtype_name: str,
|
||||
cast_bf16: bool,
|
||||
):
|
||||
os.environ["VLLM_ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_level
|
||||
os.environ["VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "1" if cast_bf16 else "0"
|
||||
_log(
|
||||
f"worker start: rank={rank} quant={quant_level} "
|
||||
f"dtype={dtype_name} cast_bf16={cast_bf16}"
|
||||
)
|
||||
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.accelerator.set_device_index(device)
|
||||
dist.init_process_group(
|
||||
backend="gloo",
|
||||
init_method=f"tcp://127.0.0.1:{port}",
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE,
|
||||
)
|
||||
|
||||
qar = None
|
||||
try:
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import (
|
||||
QuickAllReduce,
|
||||
)
|
||||
|
||||
qar = QuickAllReduce(group=dist.GroupMember.WORLD, device=rank)
|
||||
assert not qar.disabled
|
||||
|
||||
num_elements = 8 * MB if dtype_name == "float16" else 4 * MB
|
||||
|
||||
dtype = getattr(torch, dtype_name)
|
||||
inp = torch.ones(num_elements, dtype=dtype, device=device)
|
||||
|
||||
assert qar.should_quick_allreduce(inp)
|
||||
if cast_bf16:
|
||||
assert qar.use_fp16_kernels
|
||||
|
||||
out = qar.quick_all_reduce(inp)
|
||||
assert torch.allclose(out, inp * WORLD_SIZE, atol=2.5, rtol=0.1)
|
||||
_log(
|
||||
f"worker complete: rank={rank} quant={quant_level} "
|
||||
f"dtype={dtype_name} num_elements={num_elements} "
|
||||
f"use_fp16_kernels={qar.use_fp16_kernels}"
|
||||
)
|
||||
finally:
|
||||
if qar is not None:
|
||||
qar.close()
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _run_two_gpu_quick_allreduce_test(
|
||||
*,
|
||||
quant_level: str,
|
||||
dtype_name: str,
|
||||
cast_bf16: bool,
|
||||
):
|
||||
_log(
|
||||
f"launch 2-GPU case: quant={quant_level} "
|
||||
f"dtype={dtype_name} cast_bf16={cast_bf16}"
|
||||
)
|
||||
ctx = mp.get_context("spawn")
|
||||
port = get_open_port()
|
||||
procs = []
|
||||
|
||||
for rank in range(WORLD_SIZE):
|
||||
proc = ctx.Process(
|
||||
target=_quick_allreduce_worker,
|
||||
args=(rank, port, quant_level, dtype_name, cast_bf16),
|
||||
)
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
|
||||
for proc in procs:
|
||||
proc.join(timeout=60)
|
||||
assert proc.exitcode == 0, f"worker exited with code {proc.exitcode}"
|
||||
_log(
|
||||
f"finished 2-GPU case: quant={quant_level} "
|
||||
f"dtype={dtype_name} cast_bf16={cast_bf16}"
|
||||
)
|
||||
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
|
||||
E2E_PREFILL_TOKENS = 1024
|
||||
E2E_MAX_MODEL_LEN = 1536
|
||||
E2E_GPU_MEMORY_UTILIZATION = 0.3
|
||||
E2E_KV_CACHE_MEMORY_BYTES = 2 << 30
|
||||
|
||||
_BACKGROUND_LINE = (
|
||||
"Background filler: this archived operations memo repeats a routine status "
|
||||
"line so the distributed test uses a realistically long prefill."
|
||||
)
|
||||
_BACKGROUND_BLOCK = " ".join([_BACKGROUND_LINE] * 48)
|
||||
|
||||
|
||||
def _build_prompt(*, fact_block: str, question: str) -> str:
|
||||
return (
|
||||
"Read the archived operations memo below. Most of the memo is filler. "
|
||||
"Use only the fact block near the end when answering.\n"
|
||||
f"{_BACKGROUND_BLOCK}\n"
|
||||
"Fact block:\n"
|
||||
f"{fact_block}\n"
|
||||
f"Question: {question}\n"
|
||||
"Answer in one short sentence."
|
||||
)
|
||||
|
||||
|
||||
E2E_PROMPTS = [
|
||||
_build_prompt(
|
||||
fact_block=(
|
||||
"- Festival city: Oslo\n- Mascot animal: otter\n- Welcome drink: tea"
|
||||
),
|
||||
question="Which city hosts the festival, and what animal is the mascot?",
|
||||
),
|
||||
_build_prompt(
|
||||
fact_block=(
|
||||
"- Meeting day: Tuesday\n"
|
||||
"- Planned snack: apricot cake\n"
|
||||
"- Backup room: Cedar"
|
||||
),
|
||||
question="What day is the meeting, and what snack is planned?",
|
||||
),
|
||||
]
|
||||
RECORDED_RESPONSE_TEXTS = (
|
||||
" The city hosting the festival is Oslo, and the mascot is an otter.",
|
||||
" The meeting is on Tuesday and the snack planned is apricot cake.",
|
||||
)
|
||||
REQUIRED_WORDS = (("oslo", "otter"), ("tuesday", "apricot"))
|
||||
|
||||
|
||||
def _log_prompt_summaries() -> None:
|
||||
for i, prompt in enumerate(E2E_PROMPTS):
|
||||
prompt_lines = prompt.splitlines()
|
||||
fact_block = [line for line in prompt_lines if line.startswith("- ")]
|
||||
fact_summary = "; ".join(line.removeprefix("- ") for line in fact_block)
|
||||
_log(f"prompt {i} facts: {fact_summary}")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_model_path() -> str:
|
||||
try:
|
||||
path = snapshot_download(repo_id=MODEL_NAME, local_files_only=True)
|
||||
_log(f"using cached model snapshot: {path}")
|
||||
return path
|
||||
except Exception:
|
||||
path = snapshot_download(repo_id=MODEL_NAME)
|
||||
_log(f"downloaded model snapshot: {path}")
|
||||
return path
|
||||
|
||||
|
||||
def _get_hidden_size(model_config) -> int:
|
||||
hidden_size = getattr(model_config, "hidden_size", None)
|
||||
if hidden_size is None and hasattr(model_config, "text_config"):
|
||||
hidden_size = getattr(model_config.text_config, "hidden_size", None)
|
||||
assert isinstance(hidden_size, int)
|
||||
return hidden_size
|
||||
|
||||
|
||||
def _check_tp_allreduce_uses_quick_reduce(
|
||||
self,
|
||||
num_tokens: int,
|
||||
dtype_name: str = "float16",
|
||||
) -> dict[str, int | bool]:
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
|
||||
assert self.device is not None
|
||||
qr_comm = get_tp_group().device_communicator.qr_comm
|
||||
assert qr_comm is not None
|
||||
assert not qr_comm.disabled
|
||||
|
||||
hidden_size = _get_hidden_size(self.model_runner.model.config)
|
||||
dtype = getattr(torch, dtype_name)
|
||||
sample = torch.full(
|
||||
(num_tokens, hidden_size),
|
||||
fill_value=float(self.rank + 1),
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
assert qr_comm.should_quick_allreduce(sample)
|
||||
|
||||
expected = sample.clone()
|
||||
reduced = tensor_model_parallel_all_reduce(sample)
|
||||
dist.all_reduce(expected, group=get_tp_group().device_group)
|
||||
torch.testing.assert_close(reduced, expected, atol=2.5, rtol=0.1)
|
||||
|
||||
stats = {
|
||||
"rank": self.rank,
|
||||
"hidden_size": hidden_size,
|
||||
"num_tokens": num_tokens,
|
||||
"use_fp16_kernels": qr_comm.use_fp16_kernels,
|
||||
}
|
||||
_log(
|
||||
"worker quick-reduce check: "
|
||||
f"rank={self.rank} hidden_size={hidden_size} "
|
||||
f"num_tokens={num_tokens} use_fp16_kernels={qr_comm.use_fp16_kernels}"
|
||||
)
|
||||
return stats
|
||||
|
||||
|
||||
def _check_quick_reduce_disabled(self) -> int:
|
||||
from vllm.distributed.parallel_state import get_tp_group
|
||||
|
||||
qr_comm = get_tp_group().device_communicator.qr_comm
|
||||
assert qr_comm is not None
|
||||
assert qr_comm.disabled
|
||||
_log(f"worker confirmed quick reduce is disabled: rank={self.rank}")
|
||||
return self.rank
|
||||
|
||||
|
||||
def _collect_generations(outputs) -> list[tuple[tuple[int, ...], str]]:
|
||||
return [
|
||||
(tuple(output.outputs[0].token_ids), output.outputs[0].text)
|
||||
for output in outputs
|
||||
]
|
||||
|
||||
|
||||
def _shutdown_llm(llm: LLM | None) -> None:
|
||||
if llm is None:
|
||||
cleanup_dist_env_and_memory()
|
||||
return
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
llm.llm_engine.engine_core.shutdown()
|
||||
|
||||
del llm
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
def _log_generations(
|
||||
label: str,
|
||||
generations: list[tuple[tuple[int, ...], str]],
|
||||
) -> None:
|
||||
for i, (token_ids, text) in enumerate(generations):
|
||||
_log(f"{label} prompt {i} token ids: {list(token_ids)}")
|
||||
_log(f"{label} prompt {i} text: {text!r}")
|
||||
|
||||
|
||||
def _assert_required_words(
|
||||
label: str,
|
||||
generations: list[tuple[tuple[int, ...], str]],
|
||||
) -> None:
|
||||
for i, (_, text) in enumerate(generations):
|
||||
lowered = text.lower()
|
||||
missing = [word for word in REQUIRED_WORDS[i] if word not in lowered]
|
||||
assert not missing, (
|
||||
f"{label} prompt {i} is missing required words {missing}. "
|
||||
f"Observed text: {text!r}"
|
||||
)
|
||||
|
||||
|
||||
def _collect_soft_mismatches(
|
||||
baseline_generations: list[tuple[tuple[int, ...], str]],
|
||||
quick_reduce_generations: list[tuple[tuple[int, ...], str]],
|
||||
) -> list[str]:
|
||||
mismatches = []
|
||||
|
||||
for i, (_, text) in enumerate(baseline_generations):
|
||||
expected = RECORDED_RESPONSE_TEXTS[i]
|
||||
if text != expected:
|
||||
mismatches.append(
|
||||
f"baseline prompt {i} drifted from the recorded response.\n"
|
||||
f"expected={expected!r}\nactual={text!r}"
|
||||
)
|
||||
|
||||
for i, (_, text) in enumerate(quick_reduce_generations):
|
||||
expected = RECORDED_RESPONSE_TEXTS[i]
|
||||
if text != expected:
|
||||
mismatches.append(
|
||||
f"quick-reduce prompt {i} drifted from the recorded response.\n"
|
||||
f"expected={expected!r}\nactual={text!r}"
|
||||
)
|
||||
|
||||
for i, ((_, baseline_text), (_, quick_reduce_text)) in enumerate(
|
||||
zip(baseline_generations, quick_reduce_generations)
|
||||
):
|
||||
if baseline_text != quick_reduce_text:
|
||||
mismatches.append(
|
||||
f"baseline and quick-reduce responses differ for prompt {i}.\n"
|
||||
f"baseline={baseline_text!r}\nquick_reduce={quick_reduce_text!r}"
|
||||
)
|
||||
|
||||
return mismatches
|
||||
|
||||
|
||||
def _run_generation(
|
||||
*,
|
||||
backend: Literal["mp", "ray"],
|
||||
quant_mode: str,
|
||||
expect_quick_reduce: bool,
|
||||
) -> list[tuple[tuple[int, ...], str]]:
|
||||
llm = None
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
m.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_mode)
|
||||
model_path = _get_model_path()
|
||||
_log(
|
||||
f"starting generation: backend={backend} quant={quant_mode} "
|
||||
f"gpu_memory_utilization={E2E_GPU_MEMORY_UTILIZATION} "
|
||||
f"kv_cache_bytes={E2E_KV_CACHE_MEMORY_BYTES} model={model_path}"
|
||||
)
|
||||
|
||||
try:
|
||||
llm = LLM(
|
||||
model=model_path,
|
||||
tokenizer=model_path,
|
||||
tensor_parallel_size=2,
|
||||
distributed_executor_backend=backend,
|
||||
dtype="half",
|
||||
enforce_eager=True,
|
||||
max_model_len=E2E_MAX_MODEL_LEN,
|
||||
max_num_seqs=len(E2E_PROMPTS),
|
||||
gpu_memory_utilization=E2E_GPU_MEMORY_UTILIZATION,
|
||||
kv_cache_memory_bytes=E2E_KV_CACHE_MEMORY_BYTES,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
if not expect_quick_reduce:
|
||||
assert llm.collective_rpc(_check_quick_reduce_disabled) == [0, 1]
|
||||
|
||||
if expect_quick_reduce:
|
||||
worker_stats = llm.collective_rpc(
|
||||
_check_tp_allreduce_uses_quick_reduce,
|
||||
args=(E2E_PREFILL_TOKENS,),
|
||||
)
|
||||
assert [stat["rank"] for stat in worker_stats] == [0, 1]
|
||||
worker_summary = "; ".join(
|
||||
"rank={rank} hidden_size={hidden_size} num_tokens={num_tokens} "
|
||||
"use_fp16_kernels={use_fp16_kernels}".format(**stat)
|
||||
for stat in worker_stats
|
||||
)
|
||||
_log(f"{backend} quick-reduce worker checks: {worker_summary}")
|
||||
|
||||
outputs = llm.generate(
|
||||
E2E_PROMPTS,
|
||||
SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=20,
|
||||
stop=["\nAnswer:", " Answer:"],
|
||||
),
|
||||
use_tqdm=False,
|
||||
)
|
||||
generations = _collect_generations(outputs)
|
||||
assert all(text.strip() for _, text in generations)
|
||||
_log_generations(f"{backend} {quant_mode}", generations)
|
||||
return generations
|
||||
finally:
|
||||
_shutdown_llm(llm)
|
||||
|
||||
|
||||
def _run_quick_reduce_llm_e2e_in_subprocess(
|
||||
*,
|
||||
backend: Literal["mp", "ray"],
|
||||
) -> str | None:
|
||||
_log(f"running LLM e2e: backend={backend}")
|
||||
_log_prompt_summaries()
|
||||
baseline_outputs = _run_generation(
|
||||
backend=backend,
|
||||
quant_mode="NONE",
|
||||
expect_quick_reduce=False,
|
||||
)
|
||||
quick_reduce_outputs = _run_generation(
|
||||
backend=backend,
|
||||
quant_mode="FP",
|
||||
expect_quick_reduce=True,
|
||||
)
|
||||
|
||||
_assert_required_words("baseline", baseline_outputs)
|
||||
_assert_required_words("quick-reduce", quick_reduce_outputs)
|
||||
|
||||
mismatches = _collect_soft_mismatches(baseline_outputs, quick_reduce_outputs)
|
||||
if mismatches:
|
||||
details = "\n\n".join(mismatches)
|
||||
_log(f"soft response mismatch:\n{details}")
|
||||
return details
|
||||
|
||||
_log(f"LLM e2e backend={backend} matched the recorded responses exactly")
|
||||
return None
|
||||
|
||||
|
||||
def _quick_reduce_llm_e2e_worker(
|
||||
result_queue: mp.Queue,
|
||||
backend: Literal["mp", "ray"],
|
||||
) -> None:
|
||||
try:
|
||||
xfail_reason = _run_quick_reduce_llm_e2e_in_subprocess(backend=backend)
|
||||
except Exception:
|
||||
result_queue.put({"status": "error", "reason": traceback.format_exc()})
|
||||
raise
|
||||
else:
|
||||
if xfail_reason is not None:
|
||||
result_queue.put({"status": "xfail", "reason": xfail_reason})
|
||||
else:
|
||||
result_queue.put({"status": "ok"})
|
||||
|
||||
|
||||
def run_quick_reduce_llm_e2e(
|
||||
*,
|
||||
backend: Literal["mp", "ray"],
|
||||
) -> None:
|
||||
ctx = mp.get_context("spawn")
|
||||
result_queue = ctx.Queue()
|
||||
proc = ctx.Process(
|
||||
target=_quick_reduce_llm_e2e_worker,
|
||||
args=(result_queue, backend),
|
||||
)
|
||||
proc.start()
|
||||
proc.join(timeout=600)
|
||||
|
||||
try:
|
||||
result = result_queue.get(timeout=5)
|
||||
except queue.Empty as exc:
|
||||
if proc.exitcode != 0:
|
||||
raise AssertionError(
|
||||
f"quick-reduce llm e2e subprocess failed for backend={backend} "
|
||||
f"with exit code {proc.exitcode} and produced no result"
|
||||
) from exc
|
||||
raise AssertionError(
|
||||
f"quick-reduce llm e2e subprocess produced no result for backend={backend}"
|
||||
) from exc
|
||||
|
||||
if result["status"] == "xfail":
|
||||
pytest.xfail(result["reason"])
|
||||
if result["status"] == "error":
|
||||
raise AssertionError(
|
||||
f"quick-reduce llm e2e subprocess failed for backend={backend}:\n"
|
||||
f"{result['reason']}"
|
||||
)
|
||||
|
||||
assert proc.exitcode == 0, (
|
||||
f"quick-reduce llm e2e subprocess failed for backend={backend} "
|
||||
f"with exit code {proc.exitcode}"
|
||||
)
|
||||
|
||||
|
||||
def test_quick_reduce_regime_values():
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import QuickReduceRegime
|
||||
|
||||
assert QuickReduceRegime.FP.value == 0
|
||||
assert QuickReduceRegime.INT8.value == 1
|
||||
assert QuickReduceRegime.INT6.value == 2
|
||||
assert QuickReduceRegime.INT4.value == 3
|
||||
assert QuickReduceRegime.NONE.value == 4
|
||||
|
||||
|
||||
def test_quick_reduce_regime_names():
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import QuickReduceRegime
|
||||
|
||||
assert set(QuickReduceRegime.__members__) == {"FP", "INT8", "INT6", "INT4", "NONE"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("quant_level", QUANT_LEVELS + ["NONE"])
|
||||
def test_quick_reduce_quantization_env_var(monkeypatch, quant_level):
|
||||
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", quant_level)
|
||||
|
||||
reloaded_envs = _reload_envs()
|
||||
assert quant_level == reloaded_envs.VLLM_ROCM_QUICK_REDUCE_QUANTIZATION
|
||||
|
||||
|
||||
def test_quick_reduce_quantization_default(monkeypatch):
|
||||
monkeypatch.delenv("VLLM_ROCM_QUICK_REDUCE_QUANTIZATION", raising=False)
|
||||
|
||||
reloaded_envs = _reload_envs()
|
||||
assert reloaded_envs.VLLM_ROCM_QUICK_REDUCE_QUANTIZATION == "NONE"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cast_bf16", [True, False])
|
||||
def test_quick_reduce_cast_bf16_to_fp16_env_var(monkeypatch, cast_bf16):
|
||||
monkeypatch.setenv(
|
||||
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", "1" if cast_bf16 else "0"
|
||||
)
|
||||
|
||||
reloaded_envs = _reload_envs()
|
||||
assert reloaded_envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16 is cast_bf16
|
||||
|
||||
|
||||
def test_quick_reduce_cast_bf16_to_fp16_default(monkeypatch):
|
||||
monkeypatch.delenv("VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16", raising=False)
|
||||
|
||||
reloaded_envs = _reload_envs()
|
||||
assert reloaded_envs.VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16 is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_mb", [128, 512, 2048, None])
|
||||
def test_quick_reduce_max_size_env_var(monkeypatch, max_mb):
|
||||
if max_mb is None:
|
||||
monkeypatch.delenv("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", raising=False)
|
||||
else:
|
||||
monkeypatch.setenv("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", str(max_mb))
|
||||
|
||||
reloaded_envs = _reload_envs()
|
||||
assert max_mb == reloaded_envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB
|
||||
|
||||
|
||||
def test_quick_reduce_max_size_default(monkeypatch):
|
||||
monkeypatch.delenv("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", raising=False)
|
||||
|
||||
reloaded_envs = _reload_envs()
|
||||
assert reloaded_envs.VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("gcn_arch_name", "expected"),
|
||||
[
|
||||
("gfx942", True),
|
||||
("gfx950", True),
|
||||
("gfx90a", False),
|
||||
("", False),
|
||||
],
|
||||
)
|
||||
def test_quick_allreduce_rocm_arch_available(gcn_arch_name, expected):
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce
|
||||
|
||||
qar = QuickAllReduce.__new__(QuickAllReduce)
|
||||
qar.disabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.distributed.device_communicators.quick_all_reduce.current_platform."
|
||||
"is_rocm",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"torch.cuda.get_device_properties",
|
||||
return_value=SimpleNamespace(gcnArchName=gcn_arch_name),
|
||||
),
|
||||
):
|
||||
assert qar._rocm_arch_available() is expected
|
||||
|
||||
|
||||
def test_quick_allreduce_rocm_arch_available_handles_probe_failure():
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce
|
||||
|
||||
qar = QuickAllReduce.__new__(QuickAllReduce)
|
||||
qar.disabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.distributed.device_communicators.quick_all_reduce.current_platform."
|
||||
"is_rocm",
|
||||
return_value=True,
|
||||
),
|
||||
patch("torch.cuda.get_device_properties", side_effect=RuntimeError),
|
||||
):
|
||||
assert qar._rocm_arch_available() is False
|
||||
|
||||
|
||||
def test_quick_allreduce_rejects_disabled():
|
||||
qar = _make_quick_allreduce(disabled=True)
|
||||
|
||||
inp = torch.zeros(1024, dtype=torch.float16)
|
||||
assert qar.should_quick_allreduce(inp) is False
|
||||
|
||||
|
||||
def test_quick_allreduce_rejects_unsupported_dtype():
|
||||
qar = _make_quick_allreduce()
|
||||
|
||||
inp = torch.zeros(1024 * 1024, dtype=torch.float32)
|
||||
assert qar.should_quick_allreduce(inp) is False
|
||||
|
||||
|
||||
def test_quick_allreduce_rejects_non_aligned_input():
|
||||
qar = _make_quick_allreduce()
|
||||
|
||||
inp = torch.zeros(5, dtype=torch.float16)
|
||||
assert qar.should_quick_allreduce(inp) is False
|
||||
|
||||
|
||||
def test_quick_allreduce_rejects_non_contiguous_input():
|
||||
qar = _make_quick_allreduce()
|
||||
|
||||
inp = torch.zeros((1024, 1024), dtype=torch.float16)[:, ::2]
|
||||
assert qar.should_quick_allreduce(inp) is False
|
||||
|
||||
|
||||
def test_quick_allreduce_rejects_input_smaller_than_threshold():
|
||||
qar = _make_quick_allreduce()
|
||||
|
||||
inp = torch.zeros((MB // 2) - 8, dtype=torch.float16)
|
||||
assert qar.should_quick_allreduce(inp) is False
|
||||
|
||||
|
||||
def test_quick_allreduce_accepts_input_at_threshold():
|
||||
qar = _make_quick_allreduce()
|
||||
|
||||
inp = torch.zeros(MB // 2, dtype=torch.float16)
|
||||
assert qar.should_quick_allreduce(inp) is True
|
||||
|
||||
|
||||
def test_quick_allreduce_rejects_input_larger_than_max_size():
|
||||
qar = _make_quick_allreduce(qr_max_size=1 * MB)
|
||||
|
||||
inp = torch.zeros(MB, dtype=torch.float16)
|
||||
assert qar.should_quick_allreduce(inp) is False
|
||||
|
||||
|
||||
def test_quick_allreduce_bf16_uses_fp16_threshold_when_cast_enabled():
|
||||
inp = torch.zeros(MB // 2, dtype=torch.bfloat16)
|
||||
|
||||
without_cast = _make_quick_allreduce(use_fp16_kernels=False)
|
||||
with_cast = _make_quick_allreduce(use_fp16_kernels=True)
|
||||
|
||||
assert without_cast.should_quick_allreduce(inp) is False
|
||||
assert with_cast.should_quick_allreduce(inp) is True
|
||||
|
||||
|
||||
def test_quick_allreduce_supported_world_sizes():
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce
|
||||
|
||||
assert QuickAllReduce._SUPPORTED_WORLD_SIZES == [2, 4, 8]
|
||||
|
||||
|
||||
def test_quick_allreduce_supported_dtypes():
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce
|
||||
|
||||
assert [torch.float16, torch.bfloat16] == QuickAllReduce._SUPPORTED_DTYPES
|
||||
|
||||
|
||||
def test_quick_allreduce_min_size_table():
|
||||
from vllm.distributed.device_communicators.quick_all_reduce import QuickAllReduce
|
||||
|
||||
for dtype in [torch.float16, torch.bfloat16]:
|
||||
for world_size in QuickAllReduce._SUPPORTED_WORLD_SIZES:
|
||||
min_sizes = QuickAllReduce._QR_MIN_SIZE[(dtype, world_size)]
|
||||
assert len(min_sizes) == 4
|
||||
assert all(size > 0 for size in min_sizes)
|
||||
|
||||
|
||||
def test_qr_max_size():
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
max_size = ops.qr_max_size()
|
||||
assert isinstance(max_size, int)
|
||||
assert max_size > 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.device_count() < WORLD_SIZE,
|
||||
reason="requires 2 ROCm GPUs",
|
||||
)
|
||||
@pytest.mark.parametrize("quant_level", QUANT_LEVELS)
|
||||
def test_quick_allreduce_two_gpu_correctness(quant_level):
|
||||
_log(f"two-GPU correctness case: quant={quant_level}")
|
||||
_run_two_gpu_quick_allreduce_test(
|
||||
quant_level=quant_level,
|
||||
dtype_name="float16",
|
||||
cast_bf16=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.device_count() < WORLD_SIZE,
|
||||
reason="requires 2 ROCm GPUs",
|
||||
)
|
||||
def test_quick_allreduce_bf16_cast_mode():
|
||||
_log("BF16 cast case")
|
||||
_run_two_gpu_quick_allreduce_test(
|
||||
quant_level="FP",
|
||||
dtype_name="bfloat16",
|
||||
cast_bf16=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.device_count() < WORLD_SIZE,
|
||||
reason="requires 2 ROCm GPUs",
|
||||
)
|
||||
def test_quick_allreduce_llm_e2e():
|
||||
_log("LLM e2e case: backend=mp")
|
||||
run_quick_reduce_llm_e2e(backend="mp")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.device_count() < WORLD_SIZE,
|
||||
reason="requires 2 ROCm GPUs",
|
||||
)
|
||||
def test_quick_allreduce_llm_e2e_ray():
|
||||
_log("LLM e2e case: backend=ray")
|
||||
run_quick_reduce_llm_e2e(backend="ray")
|
||||
+26
-31
@@ -1363,43 +1363,38 @@ def multi_process_parallel(
|
||||
) -> None:
|
||||
import ray
|
||||
|
||||
# Using ray helps debugging the error when it failed
|
||||
# as compared to multiprocessing.
|
||||
# NOTE: We need to set working_dir for distributed tests,
|
||||
# otherwise we may get import errors on ray workers
|
||||
# NOTE: Force ray not to use gitignore file as excluding, otherwise
|
||||
# it will not move .so files to working dir.
|
||||
# So we have to manually add some of large directories
|
||||
os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1"
|
||||
# Using ray helps debugging the error when it failed as compared to
|
||||
# multiprocessing. For local Ray workers, putting the repo root on
|
||||
# PYTHONPATH is enough and avoids uploading the full source tree, which
|
||||
# exceeds Ray's working_dir package size limit on CI.
|
||||
env_vars = {
|
||||
"PYTHONPATH": os.pathsep.join(
|
||||
filter(None, [str(VLLM_PATH), os.environ.get("PYTHONPATH")])
|
||||
),
|
||||
**{env_var: "1" for env_var in current_platform.ray_noset_device_env_vars},
|
||||
}
|
||||
ray.init(
|
||||
runtime_env={
|
||||
"working_dir": VLLM_PATH,
|
||||
"excludes": [
|
||||
"build",
|
||||
".git",
|
||||
"cmake-build-*",
|
||||
"shellcheck",
|
||||
"dist",
|
||||
"ep_kernels_workspace",
|
||||
],
|
||||
"env_vars": env_vars,
|
||||
}
|
||||
)
|
||||
|
||||
distributed_init_port = get_open_port()
|
||||
refs = []
|
||||
for rank in range(tp_size * pp_size):
|
||||
refs.append(
|
||||
test_target.remote(
|
||||
monkeypatch,
|
||||
tp_size,
|
||||
pp_size,
|
||||
rank,
|
||||
distributed_init_port,
|
||||
),
|
||||
)
|
||||
ray.get(refs)
|
||||
|
||||
ray.shutdown()
|
||||
try:
|
||||
refs = []
|
||||
for rank in range(tp_size * pp_size):
|
||||
refs.append(
|
||||
test_target.remote(
|
||||
monkeypatch,
|
||||
tp_size,
|
||||
pp_size,
|
||||
rank,
|
||||
distributed_init_port,
|
||||
),
|
||||
)
|
||||
ray.get(refs)
|
||||
finally:
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
Reference in New Issue
Block a user