[Refactor] Cleanup batch invariant dead code (#41993)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-05-11 10:48:39 -04:00
committed by GitHub
parent 5f1b313900
commit 4b64fc2cbf
2 changed files with 4 additions and 35 deletions
-1
View File
@@ -1,7 +1,6 @@
#pragma once
#include <cstdlib>
#include <string>
#include <cctype>
namespace vllm {
+4 -34
View File
@@ -8,15 +8,12 @@ from typing import Any
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.mem_utils import get_max_shared_memory_bytes
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
def _matmul_launch_metadata(
grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
@@ -24,15 +21,8 @@ def _matmul_launch_metadata(
ret = {}
m, n, k = args["M"], args["N"], args["K"]
ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
if "tiles_per_update" in args:
ret["name"] = (
f"{kernel.name} [M={m}, N={n}, K={k}, "
f"tiles_per_update={args['tiles_per_update']:02}]"
)
if "c_ptr" in args:
bytes_per_elem = args["c_ptr"].element_size()
else:
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
bytes_per_elem = args["c_ptr"].element_size()
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
return ret
@@ -191,7 +181,6 @@ def matmul_persistent(
"num_warps": 8,
},
}
# print(a.device, b.device, c.device)
matmul_kernel_persistent[grid](
a,
b,
@@ -420,7 +409,7 @@ def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
input: Input tensor
dim: Dimension along which to compute log_softmax
(only -1 or last dim supported)
>> Stashed changes
Returns:
Tensor with log_softmax applied along the specified dimension
"""
@@ -910,18 +899,11 @@ def linear_batch_invariant(input, weight, bias=None):
_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None
_original_fp16_reduction_precision = None
_original_bf16_reduction_precision = None
_original_cublas_workspace_cfg = None
_original_cublaslt_workspace_size = None
_fp16_block_size_n = 256
def enable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
global _original_fp16_reduction_precision, _original_bf16_reduction_precision
global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
global _batch_invariant_MODE, _batch_invariant_LIB
global _fp16_block_size_n
if _batch_invariant_MODE:
@@ -941,10 +923,6 @@ def enable_batch_invariant_mode():
# Hopper (SM90) and Blackwell (SM100): the only source of batch
# variance is split-k, which we disable via the cuBLAS workspace
# config.
_original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
_original_cublaslt_workspace_size = os.environ.get(
"CUBLASLT_WORKSPACE_SIZE", None
)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
@@ -966,16 +944,8 @@ def enable_batch_invariant_mode():
_batch_invariant_LIB.impl(
"aten::bmm", bmm_batch_invariant, "CUDA", allow_override=True
)
_original_torch_bmm = torch.bmm
torch.bmm = bmm_batch_invariant
_original_bf16_reduction_precision = (
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
)
_original_fp16_reduction_precision = (
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
)
reduced_precision_val = (
(False, False) if is_torch_equal_or_newer("2.10.0") else False
)