mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Bugfix] moe lora align kernel grid (#40131)
Signed-off-by: TheDuyIT <nduy250299@gmail.com> Signed-off-by: Jee Jee Li <jeejeelee@inferact.ai> Signed-off-by: dtnguyen <dtnguyen@nvidia.com> Co-authored-by: Jee Jee Li <jeejeelee@inferact.ai> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -390,7 +390,13 @@ __global__ void moe_lora_align_block_size_kernel(
|
||||
int32_t* __restrict__ token_mask, bool has_expert_map) {
|
||||
int lora_idx = blockIdx.x / 2;
|
||||
int lora_id = lora_ids[lora_idx];
|
||||
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
|
||||
// Output buffers are indexed by lora_id (in [0, max_loras)). The grid
|
||||
// iterates one extra slot to accommodate the "-1" entry that
|
||||
// active_lora_ids may hold in position 0 for mixed base + LoRA batches;
|
||||
// guard against any other unexpected lora_id >= max_loras to avoid
|
||||
// out-of-bounds writes. This mirrors the `lora_id >= max_loras` guard in
|
||||
// the Triton _fused_moe_lora_kernel.
|
||||
if (lora_id == -1 || lora_id >= max_loras || adapter_enabled[lora_id] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -420,10 +426,21 @@ __global__ void lora_count_and_sort_expert_tokens_kernel(
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
|
||||
int32_t* __restrict__ expert_map, size_t numel, int32_t num_experts,
|
||||
int32_t max_num_tokens_padded, int32_t topk_num, int32_t* token_mask,
|
||||
int32_t* lora_ids, bool has_expert_map) {
|
||||
int32_t max_loras, int32_t* lora_ids, int32_t* adapter_enabled,
|
||||
bool has_expert_map) {
|
||||
int lora_idx = blockIdx.x;
|
||||
int lora_id = lora_ids[lora_idx];
|
||||
if (lora_id == -1) {
|
||||
// Same guard rationale as moe_lora_align_block_size_kernel. Additionally
|
||||
// skip disabled adapter slots: moe_lora_align_block_size_kernel early-returns
|
||||
// for them and leaves token_mask[lora_id, :] uninitialized (token_mask is
|
||||
// allocated with torch::empty), so running the sort loop here would traverse
|
||||
// garbage mask bits and pollute this slot's rows of sorted_token_ids and
|
||||
// cumsum_buffer. Downstream consumers already skip disabled slots, so the
|
||||
// pollution is dormant today, but the check keeps behavior symmetric with
|
||||
// the other two align kernels and avoids O(numel) wasted work per disabled
|
||||
// slot. Short-circuit evaluation ensures adapter_enabled is only indexed
|
||||
// after lora_id has passed the `== -1` and `>= max_loras` checks.
|
||||
if (lora_id == -1 || lora_id >= max_loras || adapter_enabled[lora_id] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -446,7 +463,8 @@ __global__ void moe_lora_align_block_size_small_batch_expert_kernel(
|
||||
int32_t* token_mask, bool has_expert_map) {
|
||||
int lora_idx = blockIdx.x;
|
||||
int lora_id = lora_ids[lora_idx];
|
||||
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
|
||||
// Same guard rationale as moe_lora_align_block_size_kernel.
|
||||
if (lora_id == -1 || lora_id >= max_loras || adapter_enabled[lora_id] == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -698,7 +716,15 @@ void moe_lora_align_block_size(
|
||||
scalar_t, fill_threads>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem));
|
||||
kernel<<<max_loras, blockDim, shared_mem, stream>>>(
|
||||
// Grid size is (max_loras + 1) because active_lora_ids has length
|
||||
// max_loras + 1: sorted-unique values of token_lora_mapping, which
|
||||
// can include -1 (base-model tokens) in addition to up to max_loras
|
||||
// real LoRA slots. Using max_loras would drop the real LoRA slot
|
||||
// when -1 is present at position 0 and leave output buffers
|
||||
// uninitialized, causing illegal memory accesses in downstream
|
||||
// MoE-LoRA kernels. This mirrors the fix made for the Triton
|
||||
// _fused_moe_lora_kernel grid in vllm-project/vllm#32277.
|
||||
kernel<<<max_loras + 1, blockDim, shared_mem, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
token_lora_mapping.data_ptr<int32_t>(), block_size,
|
||||
expert_map.data_ptr<int32_t>(), num_experts, max_loras,
|
||||
@@ -722,10 +748,17 @@ void moe_lora_align_block_size(
|
||||
auto align_kernel =
|
||||
vllm::moe::moe_lora_align_block_size_kernel<scalar_t>;
|
||||
|
||||
// launch two threadblocks for each lora
|
||||
// Launch two threadblocks per LoRA slot, across max_loras + 1 slots
|
||||
// to cover the extra "-1" (base-model tokens) entry that
|
||||
// active_lora_ids may contain in addition to up to max_loras real
|
||||
// LoRA slots. Using max_loras would drop the real LoRA slot when -1
|
||||
// occupies position 0 and leave the output buffers uninitialized,
|
||||
// causing illegal memory accesses downstream. Mirrors the grid fix
|
||||
// applied to _fused_moe_lora_kernel in vllm-project/vllm#32277.
|
||||
// blockIdx.x % 2 == 0: counting experts and aligning
|
||||
// blockIdx.x % 2 == 1: filling sorted_token_ids
|
||||
align_kernel<<<max_loras * 2, blockDim, shared_mem_size, stream>>>(
|
||||
align_kernel<<<(max_loras + 1) * 2, blockDim, shared_mem_size,
|
||||
stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
token_lora_mapping.data_ptr<int32_t>(), block_size,
|
||||
expert_map.data_ptr<int32_t>(), num_experts, max_loras,
|
||||
@@ -744,7 +777,10 @@ void moe_lora_align_block_size(
|
||||
const int max_blocks = 65535;
|
||||
const int actual_blocks = std::min(num_blocks, max_blocks);
|
||||
|
||||
dim3 gridDims(max_loras, actual_blocks);
|
||||
// Same rationale as align_kernel above: iterate over max_loras + 1
|
||||
// slots so the sort kernel processes the real LoRA slot even when
|
||||
// active_lora_ids has -1 at position 0.
|
||||
dim3 gridDims(max_loras + 1, actual_blocks);
|
||||
auto sort_kernel =
|
||||
vllm::moe::lora_count_and_sort_expert_tokens_kernel<scalar_t>;
|
||||
|
||||
@@ -753,7 +789,8 @@ void moe_lora_align_block_size(
|
||||
sorted_token_ids.data_ptr<int32_t>(), cumsum.data_ptr<int32_t>(),
|
||||
expert_map.data_ptr<int32_t>(), topk_ids.numel(), num_experts,
|
||||
max_num_tokens_padded, topk_num, token_mask.data_ptr<int32_t>(),
|
||||
lora_ids.data_ptr<int32_t>(), has_expert_map);
|
||||
max_loras, lora_ids.data_ptr<int32_t>(),
|
||||
adapter_enabled.data_ptr<int32_t>(), has_expert_map);
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -104,5 +104,222 @@ def test_moe_lora_align_block_size(
|
||||
assert torch.all(topk_ids.view(-1)[indices] == expert_id)
|
||||
|
||||
|
||||
# Out-of-domain marker values for the regression tests below. Each one is an
# int the kernel never produces for the corresponding buffer, so finding it
# after the call means "the kernel never wrote this slot".

# expert_ids: the kernel only writes a real expert id in [0, num_experts) or -1.
SENTINEL_EXPERT = -2
# sorted_token_ids: the kernel only writes a token index or the `numel`
# padding value.
SENTINEL_TOKEN = -7
# num_tokens_post_pad: the kernel only writes a block-aligned cumsum count.
SENTINEL_NPAD = -13
|
||||
|
||||
|
||||
def _build_and_run_align(
    *,
    num_lora_tokens,
    num_base_tokens,
    max_loras,
    num_experts=64,
    topk_num=6,
    block_size=16,
    lora_ids_override=None,
    disabled_slots=(),
    seed=1,
):
    """Construct inputs, invoke ``moe_lora_align_block_size``, return results.

    Inputs are laid out the way ``LoRAKernelMeta.prepare_tensors`` lays them
    out. The three output buffers start out filled with the ``SENTINEL_*``
    values, so a caller can tell which slots the kernel wrote and which it
    left alone.

    Token layout: the first ``num_lora_tokens`` tokens map to LoRA slot 0 and
    the following ``num_base_tokens`` tokens map to -1 (base model). This is
    the "mixed base + 1 LoRA" shape used to repro vllm-project/vllm#32235.

    Args:
        num_lora_tokens: number of leading tokens assigned to LoRA slot 0.
        num_base_tokens: number of trailing tokens assigned to -1.
        max_loras: number of LoRA slots; sizes the output buffers.
        num_experts: size of the expert pool each token samples from.
        topk_num: experts selected per token (distinct within a token).
        block_size: alignment block size handed to the op.
        lora_ids_override: optional 1-D int tensor of length ``max_loras + 1``
            used verbatim as the active-LoRA id list. When omitted, the
            sorted-unique values of the token mapping fill the head and the
            tail is -1, as ``prepare_tensors`` does.
        disabled_slots: iterable of slot indices to clear in
            ``adapter_enabled``.
        seed: seed applied to the global ``random`` module before sampling.

    Returns:
        dict with the ``lora_ids``, ``sorted_token_ids``, ``expert_ids`` and
        ``num_tokens_post_pad`` tensors plus the derived
        ``max_num_tokens_padded``, ``block_size`` and ``max_loras`` values.
    """
    random.seed(seed)
    num_tokens = num_lora_tokens + num_base_tokens
    assert num_tokens > 0, "test requires at least one token"

    def _prefilled(length, fill_value):
        # 1-D int32 device buffer holding `fill_value` everywhere.
        return torch.full(
            (length,), fill_value, dtype=torch.int32, device=DEVICE_TYPE
        )

    # Per token: shuffle the full expert pool and keep its first `topk_num`
    # entries, which yields `topk_num` distinct expert ids.
    topk_ids = torch.zeros((num_tokens, topk_num), dtype=torch.int32)
    for tok in range(num_tokens):
        experts = list(range(num_experts))
        random.shuffle(experts)
        topk_ids[tok] = torch.tensor(
            [experts[k] for k in range(topk_num)], dtype=torch.int32
        )
    token_lora_mapping = torch.tensor(
        [0 if tok < num_lora_tokens else -1 for tok in range(num_tokens)],
        dtype=torch.int32,
    )
    topk_ids = topk_ids.to(DEVICE_TYPE)
    token_lora_mapping = token_lora_mapping.to(DEVICE_TYPE)

    # Padded capacity of one slot's sorted_token_ids row, and the matching
    # number of block_size-sized blocks.
    total_picks = topk_ids.numel()
    max_num_tokens_padded = round_up(
        total_picks + num_experts * (block_size - 1), block_size
    )
    if total_picks < num_experts:
        max_num_tokens_padded = total_picks * block_size
    max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)

    if lora_ids_override is not None:
        assert lora_ids_override.numel() == max_loras + 1
        lora_ids = lora_ids_override.to(dtype=torch.int32, device=DEVICE_TYPE)
    else:
        # Sorted-unique ids at the head, -1 everywhere after them.
        lora_ids = _prefilled(max_loras + 1, -1)
        unique_ids = torch.unique(token_lora_mapping, sorted=True)
        lora_ids[: unique_ids.numel()] = unique_ids.to(torch.int32)

    adapter_enabled = torch.ones(
        (max_loras + 1,), dtype=torch.int32, device=DEVICE_TYPE
    )
    for slot in disabled_slots:
        adapter_enabled[slot] = 0

    # Output buffers start as sentinels so untouched entries stay visible.
    sorted_token_ids = _prefilled(max_loras * max_num_tokens_padded, SENTINEL_TOKEN)
    expert_ids = _prefilled(max_loras * max_num_m_blocks, SENTINEL_EXPERT)
    num_tokens_post_pad = _prefilled(max_loras, SENTINEL_NPAD)

    ops.moe_lora_align_block_size(
        topk_ids,
        token_lora_mapping,
        num_experts,
        block_size,
        max_loras,
        max_num_tokens_padded,
        max_num_m_blocks,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_pad,
        adapter_enabled,
        lora_ids,
    )

    return dict(
        lora_ids=lora_ids,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_pad=num_tokens_post_pad,
        max_num_tokens_padded=max_num_tokens_padded,
        block_size=block_size,
        max_loras=max_loras,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_loras", [1, 2])
def test_moe_lora_align_block_size_mixed_base_and_lora(max_loras):
    """Regression test for issue #32235.

    With a mixed base + LoRA batch, ``active_lora_ids`` holds -1 at position
    0. The real LoRA slot must still be processed: all three output rows for
    that slot have to be fully overwritten.
    """
    out = _build_and_run_align(
        num_lora_tokens=8,
        num_base_tokens=8,
        max_loras=max_loras,
    )

    # Sanity check on the layout being tested.
    first_active_id = out["lora_ids"][0].item()
    assert first_active_id == -1, (
        "prepare_tensors layout mismatch: -1 expected at position 0 for mixed batch"
    )

    real_slot = 0

    # The padded token count must have been written ...
    post_pad = out["num_tokens_post_pad"][real_slot].item()
    assert post_pad != SENTINEL_NPAD, (
        f"num_tokens_post_pad[{real_slot}] was never written by the kernel; "
        "the align kernel skipped the real LoRA slot."
    )
    # ... and be a positive, in-range multiple of the block size.
    invalid_count_msg = (
        f"num_tokens_post_pad[{real_slot}]={post_pad} is not a valid block-aligned count"
    )
    assert 0 < post_pad <= out["max_num_tokens_padded"], invalid_count_msg
    assert post_pad % out["block_size"] == 0, invalid_count_msg

    # No sentinel may survive anywhere in the slot's expert_ids row.
    expert_row = out["expert_ids"].view(max_loras, -1)[real_slot]
    assert not (expert_row == SENTINEL_EXPERT).any(), (
        f"expert_ids row for slot {real_slot} has unwritten sentinel entries; "
        "the align kernel skipped the real LoRA slot."
    )

    # Same for the slot's sorted_token_ids row.
    sorted_row = out["sorted_token_ids"].view(max_loras, -1)[real_slot]
    assert not (sorted_row == SENTINEL_TOKEN).any(), (
        f"sorted_token_ids row for slot {real_slot} has unwritten sentinel "
        "entries; the align kernel skipped the real LoRA slot."
    )
|
||||
|
||||
|
||||
def test_moe_lora_align_block_size_disabled_adapter_untouched():
    """A disabled adapter slot must keep every sentinel in its output rows.

    Pins the invariant protected by the ``adapter_enabled`` guard in
    ``lora_count_and_sort_expert_tokens_kernel``. Without that guard the sort
    kernel reads uninitialized ``token_mask`` values for disabled slots and
    pollutes ``sorted_token_ids`` / ``cumsum_buffer``.
    """
    max_loras = 1
    out = _build_and_run_align(
        num_lora_tokens=16,
        num_base_tokens=0,
        max_loras=max_loras,
        disabled_slots=(0,),
    )

    # Sanity: slot 0 IS present in active_lora_ids (otherwise we would only
    # exercise the lora_id == -1 / >= max_loras guards).
    assert 0 in out["lora_ids"].tolist()

    untouched_npad = out["num_tokens_post_pad"][0].item()
    assert untouched_npad == SENTINEL_NPAD, (
        "num_tokens_post_pad[0] was modified for a disabled adapter slot."
    )

    expert_row = out["expert_ids"].view(max_loras, -1)[0]
    assert not (expert_row != SENTINEL_EXPERT).any(), (
        "expert_ids row for disabled slot 0 was partially written."
    )

    # Row specifically protected by the sort-kernel adapter_enabled guard.
    sorted_row = out["sorted_token_ids"].view(max_loras, -1)[0]
    assert not (sorted_row != SENTINEL_TOKEN).any(), (
        "sorted_token_ids row for disabled slot 0 was polluted by the sort "
        "kernel; lora_count_and_sort_expert_tokens_kernel must skip "
        "adapter_enabled == 0 slots."
    )
|
||||
|
||||
|
||||
def test_moe_lora_align_block_size_lora_id_oob_guard():
    """Regression test for the ``lora_id >= max_loras`` guard.

    In production, ``LoRAKernelMeta.prepare_tensors`` fills the tail of
    ``active_lora_ids`` with -1, so the ``lora_id == -1`` check already covers
    the extra slot. Here that invariant is deliberately broken: the tail
    carries an out-of-range id (5 with max_loras=1). The explicit guard must
    prevent OOB reads of ``adapter_enabled`` and OOB writes to the
    max_loras-sized output buffers; without it an illegal-memory-access would
    surface on the next CUDA sync.
    """
    max_loras = 1
    out = _build_and_run_align(
        num_lora_tokens=16,
        num_base_tokens=0,
        max_loras=max_loras,
        lora_ids_override=torch.tensor([0, 5], dtype=torch.int32),
    )

    # The .item() call below syncs and would surface any async
    # illegal-memory-access from the OOB iteration.
    slot0_post_pad = out["num_tokens_post_pad"][0].item()
    assert slot0_post_pad != SENTINEL_NPAD, (
        "real LoRA slot 0 was skipped by the align kernel"
    )
|
||||
|
||||
|
||||
# Running this file as a script hands it to pytest.
if __name__ == "__main__":
    pytest.main([__file__])
|
||||
|
||||
Reference in New Issue
Block a user