mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
feat(cpu): add CPU support for draft model speculative decoding (#32662)
Signed-off-by: R <Ganesh.R@amd.com>
This commit is contained in:
@@ -349,6 +349,7 @@ endif()
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/activation.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/spec_decode_utils.cpp"
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
"csrc/cpu/mla_decode.cpp"
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
@@ -383,6 +384,7 @@ if (ENABLE_X86_ISA)
|
||||
"csrc/cpu/cpu_wna16.cpp"
|
||||
"csrc/cpu/cpu_fused_moe.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/spec_decode_utils.cpp"
|
||||
"csrc/cpu/cpu_attn.cpp"
|
||||
"csrc/cpu/dnnl_kernels.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp"
|
||||
@@ -395,6 +397,7 @@ if (ENABLE_X86_ISA)
|
||||
|
||||
set(VLLM_EXT_SRC_AVX2
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/spec_decode_utils.cpp"
|
||||
"csrc/cpu/cpu_attn.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp"
|
||||
# TODO: Remove these files
|
||||
|
||||
@@ -0,0 +1,409 @@
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace cpu_utils {
|
||||
|
||||
void eagle_prepare_inputs_padded_kernel_impl(
|
||||
const torch::Tensor& cu_num_draft_tokens,
|
||||
const torch::Tensor& valid_sampled_tokens_count,
|
||||
const torch::Tensor& query_start_loc_gpu,
|
||||
torch::Tensor& token_indices_to_sample,
|
||||
torch::Tensor& num_rejected_tokens_gpu, const int64_t num_reqs) {
|
||||
const int64_t* cu_draft_ptr = cu_num_draft_tokens.data_ptr<int64_t>();
|
||||
const int64_t* valid_count_ptr =
|
||||
valid_sampled_tokens_count.data_ptr<int64_t>();
|
||||
const int32_t* query_loc_ptr = query_start_loc_gpu.data_ptr<int32_t>();
|
||||
int32_t* indices_out_ptr = token_indices_to_sample.data_ptr<int32_t>();
|
||||
int64_t* rejected_out_ptr = num_rejected_tokens_gpu.data_ptr<int64_t>();
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t req_idx = 0; req_idx < num_reqs; ++req_idx) {
|
||||
int64_t start_idx = req_idx == 0 ? 0 : cu_draft_ptr[req_idx - 1];
|
||||
int64_t num_draft_tokens = cu_draft_ptr[req_idx] - start_idx;
|
||||
int64_t num_valid_tokens = valid_count_ptr[req_idx];
|
||||
|
||||
int64_t num_rejected = 0;
|
||||
if (num_draft_tokens > 0) {
|
||||
num_rejected = num_draft_tokens + 1 - num_valid_tokens;
|
||||
}
|
||||
|
||||
int32_t q_last_tok_idx = query_loc_ptr[req_idx + 1] - 1;
|
||||
int32_t index_to_sample = q_last_tok_idx - num_rejected;
|
||||
|
||||
indices_out_ptr[req_idx] = index_to_sample;
|
||||
rejected_out_ptr[req_idx] = num_rejected;
|
||||
}
|
||||
}
|
||||
|
||||
void eagle_prepare_next_token_padded_kernel_impl(
|
||||
const torch::Tensor& sampled_token_ids,
|
||||
const torch::Tensor& discard_request_mask,
|
||||
const torch::Tensor& backup_next_token_ids, torch::Tensor& next_token_ids,
|
||||
torch::Tensor& valid_sampled_tokens_count, const int64_t vocab_size,
|
||||
const int64_t num_sampled_tokens_per_req, const int64_t num_reqs) {
|
||||
const int64_t* sampled_ids_ptr = sampled_token_ids.data_ptr<int64_t>();
|
||||
const bool* discard_mask_ptr = discard_request_mask.data_ptr<bool>();
|
||||
const int64_t* backup_ids_ptr = backup_next_token_ids.data_ptr<int64_t>();
|
||||
int64_t* next_ids_out_ptr = next_token_ids.data_ptr<int64_t>();
|
||||
int64_t* valid_count_out_ptr = valid_sampled_tokens_count.data_ptr<int64_t>();
|
||||
|
||||
const int64_t stride = sampled_token_ids.stride(0);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t req_idx = 0; req_idx < num_reqs; ++req_idx) {
|
||||
const int64_t* row_ptr = sampled_ids_ptr + req_idx * stride;
|
||||
int64_t valid_count = 0;
|
||||
int64_t last_valid_token = -1;
|
||||
|
||||
for (int64_t pos = 0; pos < num_sampled_tokens_per_req; ++pos) {
|
||||
int64_t token = row_ptr[pos];
|
||||
if (token != -1 && token < vocab_size) {
|
||||
valid_count++;
|
||||
last_valid_token = token;
|
||||
}
|
||||
}
|
||||
|
||||
bool discard = discard_mask_ptr[req_idx];
|
||||
if (discard) {
|
||||
next_ids_out_ptr[req_idx] = backup_ids_ptr[req_idx];
|
||||
valid_count_out_ptr[req_idx] = 0;
|
||||
} else {
|
||||
next_ids_out_ptr[req_idx] =
|
||||
(valid_count > 0) ? last_valid_token : backup_ids_ptr[req_idx];
|
||||
valid_count_out_ptr[req_idx] = valid_count;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void eagle_step_slot_mapping_metadata_kernel_impl(
|
||||
const torch::Tensor& positions, const torch::Tensor& block_table,
|
||||
torch::Tensor& seq_lens, torch::Tensor& out_clamped_positions,
|
||||
torch::Tensor& out_slot_mapping, const int64_t block_size,
|
||||
const int64_t max_model_len, const int64_t PAD_ID) {
|
||||
const int64_t batch_size = positions.size(0);
|
||||
const int64_t input_batch_size = out_slot_mapping.size(0);
|
||||
|
||||
const int64_t* pos_ptr = positions.data_ptr<int64_t>();
|
||||
const int32_t* bt_ptr = block_table.data_ptr<int32_t>();
|
||||
int32_t* seq_lens_ptr = seq_lens.data_ptr<int32_t>();
|
||||
int64_t* out_clamped_ptr = out_clamped_positions.data_ptr<int64_t>();
|
||||
int64_t* out_slot_ptr = out_slot_mapping.data_ptr<int64_t>();
|
||||
|
||||
const int64_t bt_stride = block_table.stride(0);
|
||||
const int64_t n_blocks_per_req = block_table.size(1);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t req_idx = 0; req_idx < input_batch_size; ++req_idx) {
|
||||
if (req_idx >= batch_size) {
|
||||
out_slot_ptr[req_idx] = PAD_ID;
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t position = pos_ptr[req_idx];
|
||||
int64_t new_position = position + 1;
|
||||
bool exceeds_max = new_position >= max_model_len;
|
||||
int64_t clamped_position = exceeds_max ? 0 : new_position;
|
||||
|
||||
out_clamped_ptr[req_idx] = clamped_position;
|
||||
|
||||
int64_t block_number = clamped_position / block_size;
|
||||
block_number = std::min(block_number, n_blocks_per_req - 1);
|
||||
int32_t block_id = bt_ptr[req_idx * bt_stride + block_number];
|
||||
int64_t slot_id = block_id * block_size + (clamped_position % block_size);
|
||||
out_slot_ptr[req_idx] = exceeds_max ? PAD_ID : slot_id;
|
||||
|
||||
int32_t seq_len = seq_lens_ptr[req_idx];
|
||||
int32_t new_seq_len = exceeds_max ? 1 : (seq_len + 1);
|
||||
new_seq_len = std::min(new_seq_len, static_cast<int32_t>(max_model_len));
|
||||
seq_lens_ptr[req_idx] = new_seq_len;
|
||||
}
|
||||
}
|
||||
|
||||
void copy_and_expand_eagle_inputs_kernel_impl(
|
||||
const torch::Tensor& target_token_ids,
|
||||
const torch::Tensor& target_positions, const torch::Tensor& next_token_ids,
|
||||
torch::Tensor& out_input_ids, torch::Tensor& out_positions,
|
||||
torch::Tensor& out_is_rejected_token_mask,
|
||||
torch::Tensor& out_is_masked_token_mask,
|
||||
torch::Tensor& out_new_token_indices,
|
||||
torch::Tensor& out_hidden_state_mapping,
|
||||
const torch::Tensor& query_start_loc, const torch::Tensor& query_end_loc,
|
||||
const int64_t padding_token_id, const int64_t parallel_drafting_token_id,
|
||||
const int64_t total_input_tokens,
|
||||
const int64_t num_padding_slots_per_request, const bool shift_input_ids) {
|
||||
const int64_t num_reqs = query_end_loc.size(0);
|
||||
|
||||
const int64_t* target_ids_ptr = target_token_ids.data_ptr<int64_t>();
|
||||
const int64_t* target_pos_ptr = target_positions.data_ptr<int64_t>();
|
||||
const int64_t* next_ids_ptr = next_token_ids.data_ptr<int64_t>();
|
||||
const int32_t* query_start_ptr = query_start_loc.data_ptr<int32_t>();
|
||||
const int32_t* query_end_ptr = query_end_loc.data_ptr<int32_t>();
|
||||
|
||||
int64_t* out_ids_ptr = out_input_ids.data_ptr<int64_t>();
|
||||
int64_t* out_pos_ptr = out_positions.data_ptr<int64_t>();
|
||||
bool* out_rej_mask_ptr = out_is_rejected_token_mask.data_ptr<bool>();
|
||||
bool* out_mask_ptr = out_is_masked_token_mask.data_ptr<bool>();
|
||||
int32_t* out_new_idx_ptr = out_new_token_indices.data_ptr<int32_t>();
|
||||
int32_t* out_hidden_map_ptr = out_hidden_state_mapping.data_ptr<int32_t>();
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t req_idx = 0; req_idx < num_reqs; ++req_idx) {
|
||||
int32_t q_start = query_start_ptr[req_idx];
|
||||
int32_t next_q_start = query_start_ptr[req_idx + 1];
|
||||
int32_t q_end = query_end_ptr[req_idx];
|
||||
|
||||
int64_t num_valid_tokens =
|
||||
shift_input_ids ? (q_end - q_start) : (q_end - q_start + 1);
|
||||
int64_t input_offset = shift_input_ids ? 1 : 0;
|
||||
|
||||
int64_t out_start = q_start + req_idx * (num_padding_slots_per_request -
|
||||
(shift_input_ids ? 1 : 0));
|
||||
int64_t num_rejected = next_q_start - q_end - 1;
|
||||
int64_t total_output_tokens =
|
||||
num_valid_tokens + num_padding_slots_per_request + num_rejected;
|
||||
|
||||
int64_t start_pos = target_pos_ptr[q_start];
|
||||
int64_t bonus_token = next_ids_ptr[req_idx];
|
||||
|
||||
for (int64_t j = 0; j < total_output_tokens; ++j) {
|
||||
int64_t out_idx = out_start + j;
|
||||
bool is_valid = j < num_valid_tokens;
|
||||
bool is_bonus = j == num_valid_tokens;
|
||||
bool is_parallel = (j > num_valid_tokens) &&
|
||||
(j < num_valid_tokens + num_padding_slots_per_request);
|
||||
bool is_rejected = j >= num_valid_tokens + num_padding_slots_per_request;
|
||||
|
||||
int64_t in_idx =
|
||||
std::min(static_cast<int64_t>(q_start + input_offset + j),
|
||||
total_input_tokens - 1);
|
||||
|
||||
int64_t token_id = padding_token_id;
|
||||
if (is_valid)
|
||||
token_id = target_ids_ptr[in_idx];
|
||||
else if (is_bonus)
|
||||
token_id = bonus_token;
|
||||
else if (is_parallel)
|
||||
token_id = parallel_drafting_token_id;
|
||||
|
||||
out_ids_ptr[out_idx] = token_id;
|
||||
out_pos_ptr[out_idx] = is_rejected ? 0 : (start_pos + j);
|
||||
out_rej_mask_ptr[out_idx] = is_rejected;
|
||||
out_mask_ptr[out_idx] = is_parallel;
|
||||
|
||||
if (is_bonus || is_parallel) {
|
||||
int64_t new_token_local_idx = j - num_valid_tokens;
|
||||
int64_t new_token_out_idx =
|
||||
req_idx * num_padding_slots_per_request + new_token_local_idx;
|
||||
out_new_idx_ptr[new_token_out_idx] = out_idx;
|
||||
}
|
||||
}
|
||||
|
||||
if (shift_input_ids) {
|
||||
int64_t n_input = next_q_start - q_start;
|
||||
for (int64_t j = 0; j < n_input; ++j) {
|
||||
out_hidden_map_ptr[q_start + j] = out_start + j;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Greedy rejection sampling: accepts draft tokens while they equal the target
// argmax, stopping at the first mismatch.
//
// For each request r (one per entry of cu_num_draft_tokens, an inclusive
// cumulative sum) that is greedy (is_greedy absent, or is_greedy[r] true):
//  - for each draft position, the target argmax is written to
//    output_token_ids[r][pos]; scanning stops right after the first position
//    whose draft token differs from it;
//  - if no position differed, bonus_token_ids[r] is written at
//    output_token_ids[r][num_draft_tokens].
// Non-greedy requests are left untouched. max_spec_len is not used here.
void rejection_greedy_sample_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids, const torch::Tensor& target_argmax,
    const torch::Tensor& bonus_token_ids,
    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len) {
  const int64_t request_count = cu_num_draft_tokens.size(0);

  int64_t* out = output_token_ids.data_ptr<int64_t>();
  const int64_t* draft_cumsum = cu_num_draft_tokens.data_ptr<int64_t>();
  const int64_t* drafts = draft_token_ids.data_ptr<int64_t>();
  const int64_t* targets = target_argmax.data_ptr<int64_t>();
  const int64_t* bonus = bonus_token_ids.data_ptr<int64_t>();

  const bool* greedy_flags = nullptr;
  if (is_greedy.has_value()) {
    greedy_flags = is_greedy.value().data_ptr<bool>();
  }

  const int64_t out_row_stride = output_token_ids.stride(0);
  const int64_t bonus_row_stride = bonus_token_ids.stride(0);

#pragma omp parallel for
  for (int64_t r = 0; r < request_count; ++r) {
    const bool handled_elsewhere = (greedy_flags != nullptr) && !greedy_flags[r];
    if (handled_elsewhere) {
      continue;
    }

    const int64_t first = (r == 0) ? 0 : draft_cumsum[r - 1];
    const int64_t draft_len = draft_cumsum[r] - first;
    int64_t* out_row = out + r * out_row_stride;

    bool all_accepted = true;
    for (int64_t k = 0; k < draft_len; ++k) {
      const int64_t expected = targets[first + k];
      // The target token is emitted even at the mismatching position.
      out_row[k] = expected;
      if (drafts[first + k] != expected) {
        all_accepted = false;
        break;
      }
    }

    if (all_accepted) {
      out_row[draft_len] = bonus[r * bonus_row_stride];
    }
  }
}
|
||||
|
||||
// Random (non-greedy) rejection sampling over draft tokens.
//
// For each request r (one per entry of cu_num_draft_tokens, an inclusive
// cumulative sum) that is NOT greedy (is_greedy absent, or is_greedy[r]
// false), draft positions are visited in order. With
//   p = target_probs[token][draft_id],
//   q = draft_probs[token][draft_id]   (1.0 when no_draft_probs is true),
//   u = uniform_probs[token],
// the draft token is accepted iff q > 0 and p / q >= u. On acceptance the
// draft id is written to output_token_ids[r][pos]; on the first rejection
// recovered_token_ids[token] is written instead and scanning stops. If every
// draft token was accepted, bonus_token_ids[r] is written at
// output_token_ids[r][num_draft_tokens].
//
// draft_probs must hold a tensor whenever no_draft_probs is false.
// max_spec_len and vocab_size are not used here.
void rejection_random_sample_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids,
    const std::optional<torch::Tensor>& draft_probs,
    const torch::Tensor& target_probs, const torch::Tensor& bonus_token_ids,
    const torch::Tensor& recovered_token_ids,
    const torch::Tensor& uniform_probs,
    const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len,
    const int64_t vocab_size, const bool no_draft_probs) {
  const int64_t batch_size = cu_num_draft_tokens.size(0);

  int64_t* out_ptr = output_token_ids.data_ptr<int64_t>();
  const int64_t* cu_draft_ptr = cu_num_draft_tokens.data_ptr<int64_t>();
  const int64_t* draft_ids_ptr = draft_token_ids.data_ptr<int64_t>();
  const float* draft_probs_ptr =
      no_draft_probs ? nullptr : draft_probs.value().data_ptr<float>();
  const float* target_probs_ptr = target_probs.data_ptr<float>();
  const int64_t* bonus_ids_ptr = bonus_token_ids.data_ptr<int64_t>();
  const int64_t* recovered_ids_ptr = recovered_token_ids.data_ptr<int64_t>();
  const float* uniform_probs_ptr = uniform_probs.data_ptr<float>();
  const bool* greedy_ptr =
      is_greedy.has_value() ? is_greedy.value().data_ptr<bool>() : nullptr;

  const int64_t out_stride = output_token_ids.stride(0);
  const int64_t bonus_stride = bonus_token_ids.stride(0);
  const int64_t target_stride = target_probs.stride(0);
  const int64_t draft_probs_stride =
      no_draft_probs ? 0 : draft_probs.value().stride(0);

#pragma omp parallel for
  for (int64_t req_idx = 0; req_idx < batch_size; ++req_idx) {
    // Greedy requests are handled by the greedy kernel.
    if (greedy_ptr && greedy_ptr[req_idx]) continue;

    const int64_t start_idx = req_idx == 0 ? 0 : cu_draft_ptr[req_idx - 1];
    const int64_t end_idx = cu_draft_ptr[req_idx];
    const int64_t num_draft_tokens = end_idx - start_idx;

    bool rejected = false;
    for (int64_t pos = 0; pos < num_draft_tokens; ++pos) {
      const int64_t token_idx = start_idx + pos;
      const int64_t draft_id = draft_ids_ptr[token_idx];

      const float p = target_probs_ptr[token_idx * target_stride + draft_id];
      const float q =
          no_draft_probs
              ? 1.0f
              : draft_probs_ptr[token_idx * draft_probs_stride + draft_id];
      const float uniform_p = uniform_probs_ptr[token_idx];

      // A draft token with non-positive draft probability must never be
      // accepted. Folding that case into "ratio = 0" would still accept it
      // when uniform_p == 0 (0 >= 0), so test q > 0 explicitly.
      const bool accepted = (q > 0.0f) && ((p / q) >= uniform_p);

      if (accepted) {
        out_ptr[req_idx * out_stride + pos] = draft_id;
      } else {
        out_ptr[req_idx * out_stride + pos] = recovered_ids_ptr[token_idx];
        rejected = true;
        break;
      }
    }

    if (!rejected) {
      out_ptr[req_idx * out_stride + num_draft_tokens] =
          bonus_ids_ptr[req_idx * bonus_stride];
    }
  }
}
|
||||
|
||||
// Expands one value per request into one value per token.
//
// cu_num_tokens is read as an inclusive cumulative sum (int64): request r owns
// output[cu[r - 1] .. cu[r]) (cu[-1] taken as 0). Every slot in that range is
// set to input[r], except that a value equal to replace_from is substituted
// with replace_to first. Both output and input are accessed as int64.
void expand_kernel_impl(torch::Tensor& output, const torch::Tensor& input,
                        const torch::Tensor& cu_num_tokens,
                        const int64_t replace_from, const int64_t replace_to) {
  const int64_t request_count = cu_num_tokens.size(0);
  const int64_t* token_cumsum = cu_num_tokens.data_ptr<int64_t>();

  int64_t* dst = output.data_ptr<int64_t>();
  const int64_t* src = input.data_ptr<int64_t>();

#pragma omp parallel for
  for (int64_t r = 0; r < request_count; ++r) {
    const int64_t first = (r > 0) ? token_cumsum[r - 1] : 0;
    const int64_t past_last = token_cumsum[r];

    const int64_t raw = src[r];
    const int64_t fill_value = (raw == replace_from) ? replace_to : raw;

    // fill_n is a no-op for non-positive counts, like the empty-range loop.
    std::fill_n(dst + first, past_last - first, fill_value);
  }
}
|
||||
|
||||
// Chooses a "recovered" token for every draft position by taking the argmax
// of an adjusted probability scaled by inv_q.
//
// For each request r (one per entry of cu_num_draft_tokens, an inclusive
// cumulative sum) and each of its draft tokens t, every vocabulary entry v in
// [0, vocab_size) gets the score
//   no_draft_probs:  (v == draft_id ? 0 : target_probs[t][v]) * inv_q[r][v]
//   otherwise:       max(target_probs[t][v] - draft_probs[t][v], 0)
//                        * inv_q[r][v]
// and output_token_ids[t] receives the first v with the strictly largest
// score. inv_q is indexed per request (row r), not per token.
//
// draft_probs must hold a tensor whenever no_draft_probs is false.
void sample_recovered_tokens_kernel_impl(
    torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
    const torch::Tensor& draft_token_ids,
    const std::optional<torch::Tensor>& draft_probs,
    const torch::Tensor& target_probs, const torch::Tensor& inv_q,
    const int64_t vocab_size, const bool no_draft_probs) {
  const int64_t request_count = cu_num_draft_tokens.size(0);

  int64_t* recovered_out = output_token_ids.data_ptr<int64_t>();
  const int64_t* draft_cumsum = cu_num_draft_tokens.data_ptr<int64_t>();
  const int64_t* drafts = draft_token_ids.data_ptr<int64_t>();
  const float* target_base = target_probs.data_ptr<float>();
  const float* inv_q_base = inv_q.data_ptr<float>();

  const float* draft_base = nullptr;
  int64_t draft_row_stride = 0;
  if (!no_draft_probs) {
    draft_base = draft_probs.value().data_ptr<float>();
    draft_row_stride = draft_probs.value().stride(0);
  }

  const int64_t target_row_stride = target_probs.stride(0);
  const int64_t inv_q_row_stride = inv_q.stride(0);

#pragma omp parallel for
  for (int64_t r = 0; r < request_count; ++r) {
    const int64_t first = (r == 0) ? 0 : draft_cumsum[r - 1];
    const int64_t past_last = draft_cumsum[r];

    const float* inv_q_row = inv_q_base + r * inv_q_row_stride;

    for (int64_t t = first; t < past_last; ++t) {
      const int64_t draft_id = drafts[t];
      const float* target_row = target_base + t * target_row_stride;
      const float* draft_row =
          no_draft_probs ? nullptr : (draft_base + t * draft_row_stride);

      // Start below any reachable score so index 0 wins ties at zero.
      int64_t winner = 0;
      float winner_score = -1.0f;

      for (int64_t v = 0; v < vocab_size; ++v) {
        float adjusted = target_row[v];
        if (no_draft_probs) {
          // Without draft probabilities only the draft token is excluded.
          if (v == draft_id) {
            adjusted = 0.0f;
          }
        } else {
          const float residual = adjusted - draft_row[v];
          if (residual > 0.0f) {
            adjusted = residual;
          } else {
            adjusted = 0.0f;
          }
        }

        const float score = adjusted * inv_q_row[v];
        if (score > winner_score) {
          winner_score = score;
          winner = v;
        }
      }

      recovered_out[t] = winner;
    }
  }
}
|
||||
|
||||
} // namespace cpu_utils
|
||||
@@ -138,6 +138,61 @@ void compute_slot_mapping_kernel_impl(const torch::Tensor query_start_loc,
|
||||
torch::Tensor slot_mapping,
|
||||
const int64_t block_size);
|
||||
|
||||
namespace cpu_utils {
|
||||
void eagle_prepare_inputs_padded_kernel_impl(
|
||||
const torch::Tensor& cu_num_draft_tokens,
|
||||
const torch::Tensor& valid_sampled_tokens_count,
|
||||
const torch::Tensor& query_start_loc_gpu,
|
||||
torch::Tensor& token_indices_to_sample,
|
||||
torch::Tensor& num_rejected_tokens_gpu, const int64_t num_reqs);
|
||||
void eagle_prepare_next_token_padded_kernel_impl(
|
||||
const torch::Tensor& sampled_token_ids,
|
||||
const torch::Tensor& discard_request_mask,
|
||||
const torch::Tensor& backup_next_token_ids, torch::Tensor& next_token_ids,
|
||||
torch::Tensor& valid_sampled_tokens_count, const int64_t vocab_size,
|
||||
const int64_t num_sampled_tokens_per_req, const int64_t num_reqs);
|
||||
void eagle_step_slot_mapping_metadata_kernel_impl(
|
||||
const torch::Tensor& positions, const torch::Tensor& block_table,
|
||||
torch::Tensor& seq_lens, torch::Tensor& out_clamped_positions,
|
||||
torch::Tensor& out_slot_mapping, const int64_t block_size,
|
||||
const int64_t max_model_len, const int64_t PAD_ID);
|
||||
void copy_and_expand_eagle_inputs_kernel_impl(
|
||||
const torch::Tensor& target_token_ids,
|
||||
const torch::Tensor& target_positions, const torch::Tensor& next_token_ids,
|
||||
torch::Tensor& out_input_ids, torch::Tensor& out_positions,
|
||||
torch::Tensor& out_is_rejected_token_mask,
|
||||
torch::Tensor& out_is_masked_token_mask,
|
||||
torch::Tensor& out_new_token_indices,
|
||||
torch::Tensor& out_hidden_state_mapping,
|
||||
const torch::Tensor& query_start_loc, const torch::Tensor& query_end_loc,
|
||||
const int64_t padding_token_id, const int64_t parallel_drafting_token_id,
|
||||
const int64_t total_input_tokens,
|
||||
const int64_t num_padding_slots_per_request, const bool shift_input_ids);
|
||||
void rejection_greedy_sample_kernel_impl(
|
||||
torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
|
||||
const torch::Tensor& draft_token_ids, const torch::Tensor& target_argmax,
|
||||
const torch::Tensor& bonus_token_ids,
|
||||
const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len);
|
||||
void rejection_random_sample_kernel_impl(
|
||||
torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
|
||||
const torch::Tensor& draft_token_ids,
|
||||
const std::optional<torch::Tensor>& draft_probs,
|
||||
const torch::Tensor& target_probs, const torch::Tensor& bonus_token_ids,
|
||||
const torch::Tensor& recovered_token_ids,
|
||||
const torch::Tensor& uniform_probs,
|
||||
const std::optional<torch::Tensor>& is_greedy, const int64_t max_spec_len,
|
||||
const int64_t vocab_size, const bool no_draft_probs);
|
||||
void expand_kernel_impl(torch::Tensor& output, const torch::Tensor& input,
|
||||
const torch::Tensor& cu_num_tokens,
|
||||
const int64_t replace_from, const int64_t replace_to);
|
||||
void sample_recovered_tokens_kernel_impl(
|
||||
torch::Tensor& output_token_ids, const torch::Tensor& cu_num_draft_tokens,
|
||||
const torch::Tensor& draft_token_ids,
|
||||
const std::optional<torch::Tensor>& draft_probs,
|
||||
const torch::Tensor& target_probs, const torch::Tensor& inv_q,
|
||||
const int64_t vocab_size, const bool no_draft_probs);
|
||||
} // namespace cpu_utils
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@@ -363,6 +418,70 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"positions, Tensor block_table, Tensor(a3!) slot_mapping, SymInt "
|
||||
"block_size) -> ()",
|
||||
&compute_slot_mapping_kernel_impl);
|
||||
|
||||
// Speculative decoding kernels
|
||||
ops.def(
|
||||
"eagle_prepare_inputs_padded_kernel_impl(Tensor cu_num_draft_tokens, "
|
||||
"Tensor valid_sampled_tokens_count, Tensor query_start_loc_gpu, "
|
||||
"Tensor(a3!) token_indices_to_sample, "
|
||||
"Tensor(a4!) num_rejected_tokens_gpu, "
|
||||
"SymInt num_reqs) -> ()",
|
||||
&cpu_utils::eagle_prepare_inputs_padded_kernel_impl);
|
||||
ops.def(
|
||||
"eagle_prepare_next_token_padded_kernel_impl("
|
||||
"Tensor sampled_token_ids, Tensor discard_request_mask, "
|
||||
"Tensor backup_next_token_ids, Tensor(a3!) next_token_ids, "
|
||||
"Tensor(a4!) valid_sampled_tokens_count, SymInt vocab_size, "
|
||||
"SymInt num_sampled_tokens_per_req, SymInt num_reqs) -> ()",
|
||||
&cpu_utils::eagle_prepare_next_token_padded_kernel_impl);
|
||||
ops.def(
|
||||
"eagle_step_slot_mapping_metadata_kernel_impl("
|
||||
"Tensor positions, Tensor block_table, Tensor(a2!) seq_lens, "
|
||||
"Tensor(a3!) out_clamped_positions, Tensor(a4!) out_slot_mapping, "
|
||||
"SymInt block_size, SymInt max_model_len, SymInt PAD_ID) -> ()",
|
||||
&cpu_utils::eagle_step_slot_mapping_metadata_kernel_impl);
|
||||
ops.def(
|
||||
"copy_and_expand_eagle_inputs_kernel_impl("
|
||||
"Tensor target_token_ids, Tensor target_positions, "
|
||||
"Tensor next_token_ids, Tensor(a3!) out_input_ids, "
|
||||
"Tensor(a4!) out_positions, "
|
||||
"Tensor(a5!) out_is_rejected_token_mask, "
|
||||
"Tensor(a6!) out_is_masked_token_mask, "
|
||||
"Tensor(a7!) out_new_token_indices, "
|
||||
"Tensor(a8!) out_hidden_state_mapping, "
|
||||
"Tensor query_start_loc, Tensor query_end_loc, "
|
||||
"SymInt padding_token_id, SymInt parallel_drafting_token_id, "
|
||||
"SymInt total_input_tokens, SymInt num_padding_slots_per_request, "
|
||||
"bool shift_input_ids) -> ()",
|
||||
&cpu_utils::copy_and_expand_eagle_inputs_kernel_impl);
|
||||
ops.def(
|
||||
"rejection_greedy_sample_kernel_impl("
|
||||
"Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
|
||||
"Tensor draft_token_ids, Tensor target_argmax, "
|
||||
"Tensor bonus_token_ids, Tensor? is_greedy, "
|
||||
"SymInt max_spec_len) -> ()",
|
||||
&cpu_utils::rejection_greedy_sample_kernel_impl);
|
||||
ops.def(
|
||||
"rejection_random_sample_kernel_impl("
|
||||
"Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
|
||||
"Tensor draft_token_ids, Tensor? draft_probs, "
|
||||
"Tensor target_probs, Tensor bonus_token_ids, "
|
||||
"Tensor recovered_token_ids, Tensor uniform_probs, "
|
||||
"Tensor? is_greedy, SymInt max_spec_len, SymInt vocab_size, "
|
||||
"bool no_draft_probs) -> ()",
|
||||
&cpu_utils::rejection_random_sample_kernel_impl);
|
||||
ops.def(
|
||||
"expand_kernel_impl(Tensor(a0!) output, Tensor input, "
|
||||
"Tensor cu_num_tokens, SymInt replace_from, "
|
||||
"SymInt replace_to) -> ()",
|
||||
&cpu_utils::expand_kernel_impl);
|
||||
ops.def(
|
||||
"sample_recovered_tokens_kernel_impl("
|
||||
"Tensor(a0!) output_token_ids, Tensor cu_num_draft_tokens, "
|
||||
"Tensor draft_token_ids, Tensor? draft_probs, "
|
||||
"Tensor target_probs, Tensor inv_q, SymInt vocab_size, "
|
||||
"bool no_draft_probs) -> ()",
|
||||
&cpu_utils::sample_recovered_tokens_kernel_impl);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
@@ -45,3 +45,277 @@ def _compute_slot_mapping_kernel_impl(
|
||||
|
||||
|
||||
compute_slot_mapping_kernel = _FuncWrapper(_compute_slot_mapping_kernel_impl)
|
||||
|
||||
|
||||
def _ensure_int64(t: torch.Tensor) -> torch.Tensor:
|
||||
return t if t.dtype == torch.int64 else t.to(torch.int64)
|
||||
|
||||
|
||||
def _eagle_prepare_inputs_padded_kernel_impl(
    cu_num_draft_tokens,
    valid_sampled_tokens_count,
    query_start_loc_gpu,
    token_indices_to_sample,
    num_rejected_tokens_gpu,
    num_reqs,
):
    """Call the C++ ``eagle_prepare_inputs_padded`` op with int64 staging.

    The C++ op reads ``cu_num_draft_tokens`` and ``valid_sampled_tokens_count``
    as int64 and writes ``num_rejected_tokens_gpu`` as int64, so inputs are
    converted when needed and the rejected counts are copied back into the
    caller's tensor in its original dtype. ``query_start_loc_gpu`` and
    ``token_indices_to_sample`` are passed through unchanged.
    """
    needs_copy_back = num_rejected_tokens_gpu.dtype != torch.int64
    if needs_copy_back:
        rejected_buf = num_rejected_tokens_gpu.to(torch.int64)
    else:
        rejected_buf = num_rejected_tokens_gpu

    draft_cumsum_i64 = _ensure_int64(cu_num_draft_tokens)
    valid_counts_i64 = _ensure_int64(valid_sampled_tokens_count)

    torch.ops._C.eagle_prepare_inputs_padded_kernel_impl(
        draft_cumsum_i64,
        valid_counts_i64,
        query_start_loc_gpu,
        token_indices_to_sample,
        rejected_buf,
        num_reqs,
    )

    if needs_copy_back:
        # The op wrote into the int64 staging copy; mirror it to the caller.
        num_rejected_tokens_gpu.copy_(rejected_buf.to(num_rejected_tokens_gpu.dtype))
|
||||
|
||||
|
||||
def _eagle_prepare_next_token_padded_kernel_impl(
    sampled_token_ids,
    discard_request_mask,
    backup_next_token_ids,
    next_token_ids,
    valid_sampled_tokens_count,
    vocab_size,
    num_sampled_tokens_per_req,
    num_reqs,
    stride=None,
    BLOCK_SIZE_TOKENS=None,
):
    """Call the C++ ``eagle_prepare_next_token_padded`` op with int64 staging.

    The C++ op treats every integer tensor as int64. The two outputs
    (``next_token_ids`` and ``valid_sampled_tokens_count``) are staged as int64
    and, when the caller's tensor has another dtype, copied back after the
    call. ``stride`` and ``BLOCK_SIZE_TOKENS`` are accepted but not used here.
    """
    next_buf = _ensure_int64(next_token_ids)
    valid_buf = _ensure_int64(valid_sampled_tokens_count)

    torch.ops._C.eagle_prepare_next_token_padded_kernel_impl(
        _ensure_int64(sampled_token_ids),
        discard_request_mask,
        _ensure_int64(backup_next_token_ids),
        next_buf,
        valid_buf,
        vocab_size,
        num_sampled_tokens_per_req,
        num_reqs,
    )

    # Only outputs that were staged through a copy need to be mirrored back.
    staged_outputs = (
        (next_token_ids, next_buf),
        (valid_sampled_tokens_count, valid_buf),
    )
    for dest, staged in staged_outputs:
        if dest.dtype != torch.int64:
            dest.copy_(staged.to(dest.dtype))
|
||||
|
||||
|
||||
def _eagle_step_slot_mapping_metadata_kernel_impl(
    positions,
    block_table,
    stride,
    seq_lens,
    out_clamped_positions,
    out_slot_mapping,
    block_size,
    max_model_len,
    n_blocks_per_req,
    PAD_ID,
    batch_size=None,
):
    """Call the C++ ``eagle_step_slot_mapping_metadata`` op.

    ``stride`` and ``n_blocks_per_req`` are accepted but not forwarded to the
    op. ``batch_size``, when given, is only checked against
    ``positions.shape[0]``. Tensors are forwarded without dtype conversion.
    """
    num_positions = positions.shape[0]
    assert batch_size is None or batch_size == num_positions, (
        f"batch_size mismatch: {batch_size} vs positions.shape[0]={positions.shape[0]}"
    )

    step_op = torch.ops._C.eagle_step_slot_mapping_metadata_kernel_impl
    step_op(
        positions,
        block_table,
        seq_lens,
        out_clamped_positions,
        out_slot_mapping,
        block_size,
        max_model_len,
        PAD_ID,
    )
|
||||
|
||||
|
||||
def _copy_and_expand_eagle_inputs_kernel_impl(
    target_token_ids_ptr,
    target_positions_ptr,
    next_token_ids_ptr,
    out_input_ids_ptr,
    out_positions_ptr,
    out_is_rejected_token_mask_ptr,
    out_is_masked_token_mask_ptr,
    out_new_token_indices_ptr,
    out_hidden_state_mapping_ptr,
    query_start_loc_ptr,
    query_end_loc_ptr,
    padding_token_id,
    parallel_drafting_token_id,
    total_input_tokens,
    num_padding_slots_per_request,
    shift_input_ids,
    BLOCK_SIZE_TOKENS=None,
    BLOCK_SIZE_REQS=None,
):
    """Bridge the Triton-style call signature to the C++ implementation.

    Parameter names keep the Triton kernel's ``_ptr`` suffix, and the
    compile-time constants ``BLOCK_SIZE_TOKENS`` / ``BLOCK_SIZE_REQS`` are
    accepted but ignored because the C++ op has no use for them.

    The C++ op accesses token-id and position tensors as int64. Inputs are
    converted when needed; the two int64 outputs (``out_input_ids_ptr`` and
    ``out_positions_ptr``) are staged and copied back into the caller's
    tensors when those have a different dtype. The mask, index and mapping
    tensors are passed through unchanged.
    """
    ids_buf = _ensure_int64(out_input_ids_ptr)
    pos_buf = _ensure_int64(out_positions_ptr)

    src_ids_i64 = _ensure_int64(target_token_ids_ptr)
    src_pos_i64 = _ensure_int64(target_positions_ptr)
    bonus_ids_i64 = _ensure_int64(next_token_ids_ptr)

    torch.ops._C.copy_and_expand_eagle_inputs_kernel_impl(
        src_ids_i64,
        src_pos_i64,
        bonus_ids_i64,
        ids_buf,
        pos_buf,
        out_is_rejected_token_mask_ptr,
        out_is_masked_token_mask_ptr,
        out_new_token_indices_ptr,
        out_hidden_state_mapping_ptr,
        query_start_loc_ptr,
        query_end_loc_ptr,
        padding_token_id,
        parallel_drafting_token_id,
        total_input_tokens,
        num_padding_slots_per_request,
        shift_input_ids,
    )

    # Mirror staged int64 results back into non-int64 caller tensors.
    for dest, staged in ((out_input_ids_ptr, ids_buf), (out_positions_ptr, pos_buf)):
        if dest.dtype != torch.int64:
            dest.copy_(staged.to(dest.dtype))
|
||||
|
||||
|
||||
def _rejection_greedy_sample_kernel_impl(
    output_token_ids,
    cu_num_draft_tokens,
    draft_token_ids,
    target_argmax,
    bonus_token_ids,
    is_greedy,
    max_spec_len,
):
    """Call the C++ greedy rejection-sampling op with int64 staging.

    Every integer tensor is converted to int64 when needed. The output is
    staged as int64 and copied back into ``output_token_ids`` when that tensor
    has a different dtype. ``is_greedy`` is forwarded unchanged.
    """
    staged_output = _ensure_int64(output_token_ids)
    int_inputs = [
        _ensure_int64(tensor)
        for tensor in (
            cu_num_draft_tokens,
            draft_token_ids,
            target_argmax,
            bonus_token_ids,
        )
    ]

    torch.ops._C.rejection_greedy_sample_kernel_impl(
        staged_output,
        *int_inputs,
        is_greedy,
        max_spec_len,
    )

    if output_token_ids.dtype != torch.int64:
        output_token_ids.copy_(staged_output.to(output_token_ids.dtype))
|
||||
|
||||
|
||||
def _rejection_random_sample_kernel_impl(
    output_token_ids,
    cu_num_draft_tokens,
    draft_token_ids,
    draft_probs,
    target_probs,
    bonus_token_ids,
    recovered_token_ids,
    uniform_probs,
    is_greedy,
    max_spec_len,
    vocab_size,
    NO_DRAFT_PROBS=False,
):
    """Run the native random rejection-sampling op and publish its result.

    The C++ kernel expects int64 for all integer tensors and float32 for
    probabilities. ``uniform_probs`` is intentionally float64 on the Python
    side (to avoid exact-zero samples) and is cast to float32 only for the
    native call. When ``output_token_ids`` is not int64, the kernel's result
    is copied back into it in its original dtype.
    """
    caller_dtype = output_token_ids.dtype
    out_i64 = _ensure_int64(output_token_ids)

    cu_i64 = _ensure_int64(cu_num_draft_tokens)
    draft_i64 = _ensure_int64(draft_token_ids)
    bonus_i64 = _ensure_int64(bonus_token_ids)
    recovered_i64 = _ensure_int64(recovered_token_ids)
    uniform_f32 = uniform_probs.to(torch.float32)

    torch.ops._C.rejection_random_sample_kernel_impl(
        out_i64,
        cu_i64,
        draft_i64,
        draft_probs,
        target_probs,
        bonus_i64,
        recovered_i64,
        uniform_f32,
        is_greedy,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS,
    )

    if caller_dtype == torch.int64:
        return
    # Write the kernel's int64 result back into the caller's tensor.
    output_token_ids.copy_(out_i64.to(caller_dtype))
|
||||
|
||||
|
||||
def _expand_kernel_impl(
    output,
    input_val,
    cu_num_tokens,
    replace_from,
    replace_to,
    MAX_NUM_TOKENS=None,
):
    """Run the native expand op and publish its result into ``output``.

    Args:
        output: Tensor the expanded values are written to (in place).
        input_val: Per-request values to expand.
        cu_num_tokens: Cumulative token counts, forwarded to the op as int64.
        replace_from: Forwarded unchanged to the native op.
        replace_to: Forwarded unchanged to the native op.
        MAX_NUM_TOKENS: Accepted for call-site compatibility; not forwarded
            to the native op.

    The integer tensors are passed through ``_ensure_int64`` like the sibling
    wrappers. When ``output`` is not int64 the result is copied back in its
    original dtype; previously this step was missing, so a non-int64
    ``output`` never received the kernel's result.
    """
    orig_dtype = output.dtype
    output_i64 = _ensure_int64(output)

    torch.ops._C.expand_kernel_impl(
        output_i64,
        _ensure_int64(input_val),
        _ensure_int64(cu_num_tokens),
        replace_from,
        replace_to,
    )

    if orig_dtype != torch.int64:
        # NOTE(review): values round-trip through int64; confirm callers only
        # pass integer tensors here, otherwise fractional parts are lost.
        output.copy_(output_i64.to(orig_dtype))
|
||||
|
||||
|
||||
def _sample_recovered_tokens_kernel_impl(
    output_token_ids,
    cu_num_draft_tokens,
    draft_token_ids,
    draft_probs,
    target_probs,
    inv_q,
    vocab_size,
    BLOCK_SIZE=None,
    NO_DRAFT_PROBS=False,
):
    """Run the native recovered-token sampling op and publish its result.

    The C++ side reads integer tensors as ``int64_t*``, so the integer
    arguments go through ``_ensure_int64`` first. ``BLOCK_SIZE`` is accepted
    but not forwarded to the native op. When ``output_token_ids`` is not
    int64, the kernel's result is copied back into it in its original dtype.
    """
    caller_dtype = output_token_ids.dtype
    out_i64 = _ensure_int64(output_token_ids)

    cu_i64 = _ensure_int64(cu_num_draft_tokens)
    draft_i64 = _ensure_int64(draft_token_ids)

    torch.ops._C.sample_recovered_tokens_kernel_impl(
        out_i64,
        cu_i64,
        draft_i64,
        draft_probs,
        target_probs,
        inv_q,
        vocab_size,
        NO_DRAFT_PROBS,
    )

    if caller_dtype == torch.int64:
        return
    # Write the kernel's int64 result back into the caller's tensor.
    output_token_ids.copy_(out_i64.to(caller_dtype))
|
||||
|
||||
|
||||
# Public kernel handles: each CPU implementation above is wrapped in
# _FuncWrapper and exported under a public name without the `_impl` suffix.

# --- EAGLE input preparation -------------------------------------------------
eagle_prepare_inputs_padded_kernel = _FuncWrapper(
    _eagle_prepare_inputs_padded_kernel_impl
)
eagle_prepare_next_token_padded_kernel = _FuncWrapper(
    _eagle_prepare_next_token_padded_kernel_impl
)
copy_and_expand_eagle_inputs_kernel = _FuncWrapper(
    _copy_and_expand_eagle_inputs_kernel_impl
)
eagle_step_slot_mapping_metadata_kernel = _FuncWrapper(
    _eagle_step_slot_mapping_metadata_kernel_impl
)

# --- Rejection sampling ------------------------------------------------------
rejection_greedy_sample_kernel = _FuncWrapper(_rejection_greedy_sample_kernel_impl)
rejection_random_sample_kernel = _FuncWrapper(_rejection_random_sample_kernel_impl)
expand_kernel = _FuncWrapper(_expand_kernel_impl)
sample_recovered_tokens_kernel = _FuncWrapper(_sample_recovered_tokens_kernel_impl)
|
||||
|
||||
@@ -26,7 +26,6 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
@@ -48,6 +47,7 @@ from vllm.v1.spec_decode.utils import (
|
||||
eagle_prepare_next_token_padded_kernel,
|
||||
eagle_step_update_slot_mapping_and_metadata,
|
||||
extend_all_queries_by_N,
|
||||
next_power_of_2,
|
||||
)
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
|
||||
@@ -689,9 +689,7 @@ class SpecDecodeBaseProposer:
|
||||
max_num_tokens_per_request = (
|
||||
cad.max_query_len + self.net_num_new_slots_per_request
|
||||
)
|
||||
BLOCK_SIZE_TOKENS = min(
|
||||
256, triton.next_power_of_2(max_num_tokens_per_request)
|
||||
)
|
||||
BLOCK_SIZE_TOKENS = min(256, next_power_of_2(max_num_tokens_per_request))
|
||||
num_blocks = (
|
||||
max_num_tokens_per_request + BLOCK_SIZE_TOKENS - 1
|
||||
) // BLOCK_SIZE_TOKENS
|
||||
@@ -717,6 +715,7 @@ class SpecDecodeBaseProposer:
|
||||
query_end_loc = cad.query_start_loc[1:] - 1
|
||||
if num_rejected_tokens_gpu is not None:
|
||||
query_end_loc = query_end_loc - num_rejected_tokens_gpu
|
||||
|
||||
copy_and_expand_eagle_inputs_kernel[grid](
|
||||
# (Padded) Inputs from the target model
|
||||
target_token_ids_ptr=target_token_ids,
|
||||
@@ -899,7 +898,7 @@ class SpecDecodeBaseProposer:
|
||||
grid = (batch_size,)
|
||||
|
||||
# Find the next power of 2 for block sizes
|
||||
BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
|
||||
BLOCK_SIZE_TOKENS = next_power_of_2(num_tokens)
|
||||
eagle_prepare_next_token_padded_kernel[grid](
|
||||
sampled_token_ids,
|
||||
discard_request_mask,
|
||||
|
||||
@@ -11,6 +11,20 @@ from vllm.v1.attention.backends.utils import (
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
|
||||
def next_power_of_2(n: int) -> int:
    """Return the smallest power of 2 that is >= ``n``.

    Args:
        n: Any integer. Values <= 0 yield 1.

    Returns:
        The smallest ``2**k`` (k >= 0) with ``2**k >= n``.

    Raises:
        TypeError: If ``n`` is not an integer-like value.

    Uses ``int.bit_length`` rather than a fixed chain of shift-and-OR steps,
    so the result is also correct for integers wider than 64 bits.
    """
    import operator

    # Normalise integer-like inputs to a plain int; rejects floats.
    n = operator.index(n)
    if n <= 0:
        return 1
    # (n - 1).bit_length() is the exponent of the first power of 2 >= n.
    return 1 << (n - 1).bit_length()
|
||||
|
||||
|
||||
@triton.jit
|
||||
def eagle_step_slot_mapping_metadata_kernel(
|
||||
positions_ptr, # [batch_size] - current positions (1D view for M-RoPE)
|
||||
@@ -102,8 +116,8 @@ def eagle_step_update_slot_mapping_and_metadata(
|
||||
batch_size = positions_1d.shape[0]
|
||||
if input_batch_size is None:
|
||||
input_batch_size = batch_size
|
||||
n_blocks_per_req = block_table_tensor.shape[1]
|
||||
|
||||
n_blocks_per_req = block_table_tensor.shape[1]
|
||||
eagle_step_slot_mapping_metadata_kernel[(input_batch_size,)](
|
||||
positions_1d,
|
||||
block_table_tensor,
|
||||
|
||||
@@ -11,6 +11,8 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.tracing import instrument
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
@@ -23,7 +25,7 @@ class CPUModelRunner(GPUModelRunner):
|
||||
super().__init__(vllm_config, device)
|
||||
|
||||
assert device == torch.device("cpu")
|
||||
assert self.speculative_config is None, "spec decode is not supported."
|
||||
# Note: speculative decoding is now supported on CPU with C++ native impls
|
||||
|
||||
self.use_cuda_graph = False
|
||||
self.cascade_attn_enabled = False
|
||||
@@ -61,6 +63,34 @@ class CPUModelRunner(GPUModelRunner):
|
||||
cpu_tl.compute_slot_mapping_kernel
|
||||
)
|
||||
|
||||
# Speculative decoding fallbacks
|
||||
import vllm.v1.sample.rejection_sampler
|
||||
import vllm.v1.spec_decode.eagle
|
||||
import vllm.v1.spec_decode.utils
|
||||
|
||||
vllm.v1.spec_decode.eagle.eagle_prepare_inputs_padded_kernel = (
|
||||
cpu_tl.eagle_prepare_inputs_padded_kernel
|
||||
)
|
||||
vllm.v1.spec_decode.eagle.eagle_prepare_next_token_padded_kernel = (
|
||||
cpu_tl.eagle_prepare_next_token_padded_kernel
|
||||
)
|
||||
vllm.v1.spec_decode.eagle.copy_and_expand_eagle_inputs_kernel = (
|
||||
cpu_tl.copy_and_expand_eagle_inputs_kernel
|
||||
)
|
||||
vllm.v1.spec_decode.utils.eagle_step_slot_mapping_metadata_kernel = (
|
||||
cpu_tl.eagle_step_slot_mapping_metadata_kernel
|
||||
)
|
||||
vllm.v1.sample.rejection_sampler.rejection_greedy_sample_kernel = (
|
||||
cpu_tl.rejection_greedy_sample_kernel
|
||||
)
|
||||
vllm.v1.sample.rejection_sampler.rejection_random_sample_kernel = (
|
||||
cpu_tl.rejection_random_sample_kernel
|
||||
)
|
||||
vllm.v1.sample.rejection_sampler.expand_kernel = cpu_tl.expand_kernel
|
||||
vllm.v1.sample.rejection_sampler.sample_recovered_tokens_kernel = (
|
||||
cpu_tl.sample_recovered_tokens_kernel
|
||||
)
|
||||
|
||||
@instrument(span_name="Loading (CPU)")
|
||||
def load_model(self, load_dummy_weights: bool = False) -> None:
|
||||
if load_dummy_weights:
|
||||
@@ -74,6 +104,10 @@ class CPUModelRunner(GPUModelRunner):
|
||||
if self.lora_config:
|
||||
self.model = self.load_lora_model(self.model, self.vllm_config, self.device)
|
||||
|
||||
if hasattr(self, "drafter"):
|
||||
logger.info_once("Loading drafter model...")
|
||||
self.drafter.load_model(self.model)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
@@ -89,8 +123,29 @@ class CPUModelRunner(GPUModelRunner):
|
||||
)
|
||||
)
|
||||
|
||||
# Warm up drafter for speculative decoding
|
||||
if self.speculative_config and (self.speculative_config.uses_draft_model()):
|
||||
from vllm.v1.spec_decode.draft_model import DraftModelProposer
|
||||
|
||||
if isinstance(self.drafter, (DraftModelProposer)):
|
||||
logger.info("Warming up drafter model...")
|
||||
self.drafter.dummy_run(max(16, self.max_num_reqs))
|
||||
|
||||
logger.info("Warming up done.")
|
||||
|
||||
def initialize_kv_cache(
    self,
    kv_cache_config: KVCacheConfig,
    is_profiling: bool = False,
) -> None:
    """Initialise the KV cache via the parent class, then log the drafter kind.

    All initialisation is delegated to the parent implementation. Afterwards,
    if speculative decoding is configured, an info message reports whether an
    EAGLE drafter or a draft model is in use.
    """
    super().initialize_kv_cache(kv_cache_config, is_profiling)

    spec_cfg = self.speculative_config
    if not spec_cfg:
        return

    if spec_cfg.use_eagle():
        logger.info("EAGLE drafter KV cache initialized for CPU backend")
    elif spec_cfg.uses_draft_model():
        logger.info("Draft model KV cache initialized for CPU backend")
|
||||
|
||||
def _init_device_properties(self) -> None:
|
||||
pass
|
||||
|
||||
@@ -102,6 +157,71 @@ class CPUModelRunner(GPUModelRunner):
|
||||
# so stale KV cache data never affects computation.
|
||||
pass
|
||||
|
||||
# =========================================================================
|
||||
# CPU-safe overrides for speculative decoding methods
|
||||
# These methods override GPU-specific implementations that use CUDA streams
|
||||
# =========================================================================
|
||||
|
||||
def _copy_draft_token_ids_to_cpu(
|
||||
self, scheduler_output: "SchedulerOutput", zeros_only: bool = False
|
||||
) -> None:
|
||||
"""CPU-safe version: no async copy needed, tensors already on CPU."""
|
||||
if self.use_async_scheduling and not (
|
||||
scheduler_output.has_structured_output_requests
|
||||
or self.input_batch.sampling_metadata.output_token_ids
|
||||
):
|
||||
return
|
||||
self._draft_token_req_ids = self.input_batch.req_ids.copy()
|
||||
|
||||
draft_token_ids: torch.Tensor = self._draft_token_ids
|
||||
if not torch.is_tensor(draft_token_ids):
|
||||
return
|
||||
|
||||
num_reqs = draft_token_ids.shape[0]
|
||||
if self.draft_token_ids_cpu is not None:
|
||||
if not zeros_only:
|
||||
self.draft_token_ids_cpu[:num_reqs].copy_(draft_token_ids)
|
||||
else:
|
||||
self.draft_token_ids_cpu[:num_reqs] = 0
|
||||
|
||||
def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]:
|
||||
"""CPU-safe version: no event synchronization needed."""
|
||||
if isinstance(self._draft_token_ids, list):
|
||||
return self._draft_token_ids, self.input_batch.req_ids
|
||||
req_ids = self._draft_token_req_ids
|
||||
if req_ids is None:
|
||||
return [], []
|
||||
if self.draft_token_ids_cpu is not None:
|
||||
return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids
|
||||
return [], []
|
||||
|
||||
def _copy_valid_sampled_token_count(
|
||||
self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
|
||||
) -> None:
|
||||
"""CPU-safe version: direct copy without CUDA streams."""
|
||||
if self.valid_sampled_token_count_cpu is None:
|
||||
return
|
||||
|
||||
counts = valid_sampled_tokens_count
|
||||
counts_cpu = self.valid_sampled_token_count_cpu
|
||||
counts_cpu[: counts.shape[0]].copy_(counts)
|
||||
self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)
|
||||
|
||||
def _get_valid_sampled_token_count(self) -> list[int]:
|
||||
"""CPU-safe version: no event synchronization needed."""
|
||||
prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
|
||||
if prev_sampled_token_ids is None:
|
||||
return []
|
||||
|
||||
counts_cpu = self.valid_sampled_token_count_cpu
|
||||
if counts_cpu is None:
|
||||
return []
|
||||
return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()
|
||||
|
||||
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
|
||||
"""CPU-safe version: direct tolist() without CUDA events."""
|
||||
return sampled_token_ids.tolist()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _torch_cuda_wrapper():
|
||||
|
||||
Reference in New Issue
Block a user