[fix] Fix MoE workspace info by storing Torch tensor itself instead of data_ptr (#5900)

Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
Jinyang Yuan committed 2025-07-10 19:07:32 +08:00 (committed by GitHub)
parent 3aa53ec36c
commit 8b9a030a5c

@@ -327,7 +327,7 @@ public:
         fc2_expert_weights.const_data_ptr(),
         fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
         num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
-        static_cast<char*>(workspace_info.workspace), output.data_ptr(),
+        static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
         static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
         mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
 #else
@@ -457,7 +457,7 @@ public:
         fc2_expert_weights.const_data_ptr(),
         fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
         num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
-        static_cast<char*>(workspace_info.workspace), output.data_ptr(),
+        static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
         static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
         mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
 #else
@@ -563,7 +563,7 @@ public:
 private:
     struct WorkspaceInfo
     {
-        void* workspace{};
+        torch::Tensor workspace{};
         void* src_to_dest_map{};
     };
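Why this fixes the bug: a torch::Tensor is a reference-counted handle, so storing it by value in WorkspaceInfo keeps the underlying CUDA allocation alive for as long as the struct exists. The previous void* held no ownership: the temporary tensor created inside the workspace-allocation function died when that function returned, and PyTorch's caching allocator was then free to hand the same memory to a later allocation while the MoE kernels were still using it. Below is a minimal sketch of the two patterns, assuming only libtorch; this is illustrative code, not the TensorRT-LLM source.

#include <torch/torch.h>

struct BuggyWorkspaceInfo
{
    void* workspace{}; // raw pointer, holds no ownership
};

struct FixedWorkspaceInfo
{
    torch::Tensor workspace{}; // owning handle keeps the buffer alive
};

BuggyWorkspaceInfo makeBuggy(int64_t bytes)
{
    auto t = torch::empty({bytes}, torch::dtype(torch::kInt8).device(torch::kCUDA));
    BuggyWorkspaceInfo info{};
    info.workspace = t.data_ptr();
    return info; // `t` dies here; the caching allocator may reuse its storage
                 // while info.workspace still points at it
}

FixedWorkspaceInfo makeFixed(int64_t bytes)
{
    FixedWorkspaceInfo info{};
    info.workspace = torch::empty({bytes}, torch::dtype(torch::kInt8).device(torch::kCUDA));
    return info; // the tensor moves out with the struct, so the allocation
                 // lives as long as the caller holds info
}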
@@ -634,12 +634,12 @@ private:
         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
         size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
-        auto workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
         WorkspaceInfo info{};
-        info.workspace = workspace.data_ptr();
-        info.src_to_dest_map = common::nextWorkspacePtr(static_cast<int8_t*>(workspace.data_ptr()), moe_workspace_size);
+        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
         return info;
     }
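Note that src_to_dest_map stays a raw void* carved out of the same allocation. That is safe after this change because the owning tensor lives in the same struct, so the sub-buffer cannot outlive its backing memory. The helpers used above, common::calculateTotalWorkspaceSize and common::nextWorkspacePtr, are TensorRT-LLM internals; the sketch below is a hypothetical analogue of how such a single buffer is carved into aligned sub-buffers, with the 256-byte alignment being an assumption for illustration only.

#include <cstddef>
#include <cstdint>
#include <vector>

constexpr size_t kAlignment = 256; // assumed alignment, for illustration

size_t alignUp(size_t n)
{
    return (n + kAlignment - 1) / kAlignment * kAlignment;
}

// Sum of aligned sub-buffer sizes -- analogue of calculateTotalWorkspaceSize.
size_t totalWorkspaceSize(std::vector<size_t> const& sizes)
{
    size_t total = 0;
    for (size_t s : sizes)
    {
        total += alignUp(s);
    }
    return total;
}

// Step past one aligned sub-buffer -- analogue of nextWorkspacePtr.
int8_t* nextWorkspacePtr(int8_t* base, size_t previousSize)
{
    return base + alignUp(previousSize);
}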