diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp
index 2dc93d5a6c..5c918c25d5 100644
--- a/cpp/tensorrt_llm/thop/moeOp.cpp
+++ b/cpp/tensorrt_llm/thop/moeOp.cpp
@@ -308,8 +308,8 @@ public:
         std::vector<int64_t> output_shape = {num_rows, hidden_size};
         auto output = torch::empty(output_shape, input.options().dtype(mOutputDtype));
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
         kernels::MoeMinLatencyParams min_latency_params{};
@@ -439,8 +439,8 @@ public:
         min_latency_params.experts_to_token_score = static_cast<float*>(experts_to_token_score.data_ptr());
         min_latency_params.active_expert_global_ids = static_cast<int*>(active_expert_global_ids.data_ptr());
 
-        WorkspaceInfo workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
-            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode);
+        WorkspaceInfo const& workspace_info = getWorkspaceInfo(num_rows, hidden_size, inter_size, num_experts_total,
+            static_cast<int>(experts_per_token), activation_type, parallelism_config, min_latency_mode, stream);
 
         auto const quant_params = getQuantParams(num_experts_on_rank, hidden_size, inter_size, quant_scales);
 
@@ -577,6 +577,7 @@ private:
     // e.g. 16 nvfp4 elements are packed into a single int64 element
     int64_t mInnerDimMultiplier;
     char* mProfileWorkspace = nullptr;
+    WorkspaceInfo workspace_info;
 
     bool mUseDeepSeekFP8BlockScaling = false;
     bool mUseW4A8GroupScaling = false;
@@ -622,9 +623,9 @@ private:
         mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
     }
 
-    WorkspaceInfo getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
+    WorkspaceInfo const& getWorkspaceInfo(int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
         int num_experts, int experts_per_token, ActivationType activation_type,
-        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode)
+        kernels::MOEParallelismConfig const& parallelismConfig, bool min_latency_mode, cudaStream_t stream)
     {
         size_t moe_workspace_size = mKernelRunner->getWorkspaceSize(num_rows, hidden_size, inter_size, num_experts,
             experts_per_token, activation_type, parallelismConfig, /* use_lora */ false, mUseDeepSeekFP8BlockScaling,
@@ -633,15 +634,29 @@
         std::vector<size_t> workspaces{moe_workspace_size, src_to_dest_map_size};
 
-        size_t total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
+        int64_t const total_workspace_size = common::calculateTotalWorkspaceSize(workspaces.data(), workspaces.size());
 
-        WorkspaceInfo info{};
-        info.workspace = torch::empty({static_cast<int64_t>(total_workspace_size)},
-            torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
-        info.src_to_dest_map
-            = common::nextWorkspacePtr(static_cast<int8_t*>(info.workspace.data_ptr()), moe_workspace_size);
+        bool is_capturing = tensorrt_llm::common::isCapturing(stream);
+        // Always allocate workspace when capturing cuda graph to avoid illegal memory access during replay
+        if (is_capturing || workspace_info.workspace.numel() < total_workspace_size)
+        {
+            if (is_capturing)
+            {
+                TLLM_LOG_DEBUG(
+                    "Allocating MoE workspace with %ld bytes size during cuda graph capture", total_workspace_size);
+            }
+            else
+            {
+                TLLM_LOG_DEBUG("MoE workspace size is not enough, increase the size from %ld bytes to %ld bytes",
+                    workspace_info.workspace.numel(), total_workspace_size);
+            }
+            workspace_info.workspace = torch::empty({static_cast<int64_t>(total_workspace_size)},
+                torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
+        }
+        workspace_info.src_to_dest_map
+            = common::nextWorkspacePtr(static_cast<int8_t*>(workspace_info.workspace.data_ptr()), moe_workspace_size);
 
-        return info;
+        return workspace_info;
     }
 
     kernels::QuantParams getQuantParams(int64_t const num_experts_on_rank, int64_t const hidden_size,
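For context, a minimal standalone sketch of the pattern this patch introduces: a stream-capture check plus a grow-only workspace cache. The capture check reflects the assumed semantics of tensorrt_llm::common::isCapturing (its implementation is not part of this diff), and WorkspaceCacheSketch stands in for the new workspace_info member; all names below are illustrative, not the real API.

#include <cuda_runtime_api.h>
#include <torch/torch.h>

// Assumed semantics of tensorrt_llm::common::isCapturing: returns true
// while `stream` is recording work into a CUDA graph.
static bool isCapturingSketch(cudaStream_t stream)
{
    cudaStreamCaptureStatus status = cudaStreamCaptureStatusNone;
    return cudaStreamIsCapturing(stream, &status) == cudaSuccess && status == cudaStreamCaptureStatusActive;
}

// Grow-only cache standing in for the workspace_info member: reuse the
// buffer when it is already large enough, otherwise reallocate. During
// graph capture the buffer is always reallocated so that it comes from the
// capture's private memory pool and the pointer baked into the graph stays
// valid on every replay.
struct WorkspaceCacheSketch
{
    torch::Tensor buffer;

    torch::Tensor& get(int64_t required_bytes, cudaStream_t stream)
    {
        if (isCapturingSketch(stream) || buffer.numel() < required_bytes)
        {
            buffer = torch::empty(
                {required_bytes}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
        }
        return buffer;
    }
};

One design note: returning WorkspaceInfo const& ties the workspace's lifetime to the op instance, which is what allows reuse across forward calls, but it also means concurrent calls on the same instance would share (and race on) the cached buffer.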