Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[fix] Fix MoE workspace info by storing Torch tensor itself instead of data_ptr (#5900)
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
This commit is contained in:
parent 3aa53ec36c
commit 8b9a030a5c
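For context, the bug this commit fixes is a use-after-free: the workspace tensor was a local, so its only reference died when the helper returned, leaving the raw pointer stored in WorkspaceInfo dangling. Below is a minimal sketch of the before/after pattern; the makeWorkspace* function names are illustrative, not the actual TensorRT-LLM helpers.

#include <torch/torch.h>

// Before the fix (sketch): `workspace` is a local tensor, so its only
// reference dies when the function returns. The raw pointer stored in
// the struct then dangles, and the CUDA caching allocator is free to
// hand the same block to a later allocation.
struct WorkspaceInfoBefore
{
    void* workspace{};
    void* src_to_dest_map{};
};

WorkspaceInfoBefore makeWorkspaceBefore(int64_t total_workspace_size) // hypothetical name
{
    auto workspace = torch::empty({total_workspace_size},
        torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
    WorkspaceInfoBefore info{};
    info.workspace = workspace.data_ptr();
    return info; // `workspace` is destroyed here; info.workspace now dangles
}

// After the fix (sketch): the struct owns the tensor, so the allocation
// stays alive exactly as long as the WorkspaceInfo does; callers fetch
// info.workspace.data_ptr() at the point of use.
struct WorkspaceInfoAfter
{
    torch::Tensor workspace;
    void* src_to_dest_map{};
};

WorkspaceInfoAfter makeWorkspaceAfter(int64_t total_workspace_size) // hypothetical name
{
    WorkspaceInfoAfter info{};
    info.workspace = torch::empty({total_workspace_size},
        torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
    return info; // the tensor's refcount travels with the struct
}

Because PyTorch's CUDA caching allocator readily reuses freed blocks, the stale pointer tended to alias a later allocation rather than fault outright, which makes this class of bug intermittent and hard to trace.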
@@ -327,7 +327,7 @@ public:
             fc2_expert_weights.const_data_ptr(),
             fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
             num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
-            static_cast<char*>(workspace_info.workspace), output.data_ptr(),
+            static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
             static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
             mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
 #else
@@ -457,7 +457,7 @@ public:
             fc2_expert_weights.const_data_ptr(),
             fc2_expert_biases.has_value() ? fc2_expert_biases.value().const_data_ptr() : nullptr, quant_params,
             num_rows, hidden_size, inter_size, num_experts_total, static_cast<int>(experts_per_token),
-            static_cast<char*>(workspace_info.workspace), output.data_ptr(),
+            static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
             static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params,
             mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
 #else
@@ -563,7 +563,7 @@ public:
 private:
     struct WorkspaceInfo
     {
-        void* workspace{};
+        torch::Tensor workspace{};
         void* src_to_dest_map{};
     };
 
@@ -634,12 +634,12 @@ private:
         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
 
         size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
-        auto workspace = torch::empty({static_cast<long>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
 
         WorkspaceInfo info{};
-        info.workspace = workspace.data_ptr();
-        info.src_to_dest_map = common::nextWorkspacePtr(static_cast<int8_t*>(workspace.data_ptr()), moe_workspace_size);
+        info.workspace = torch::empty({static_cast<long>(total_workspace_size)},
+            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
 
         return info;
     }