mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[Refactor] Cleanup batch invariant dead code (#41993)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <cctype>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
|
||||
@@ -8,15 +8,12 @@ from typing import Any
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.mem_utils import get_max_shared_memory_bytes
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _matmul_launch_metadata(
|
||||
grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
|
||||
@@ -24,15 +21,8 @@ def _matmul_launch_metadata(
|
||||
ret = {}
|
||||
m, n, k = args["M"], args["N"], args["K"]
|
||||
ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
|
||||
if "tiles_per_update" in args:
|
||||
ret["name"] = (
|
||||
f"{kernel.name} [M={m}, N={n}, K={k}, "
|
||||
f"tiles_per_update={args['tiles_per_update']:02}]"
|
||||
)
|
||||
if "c_ptr" in args:
|
||||
bytes_per_elem = args["c_ptr"].element_size()
|
||||
else:
|
||||
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
|
||||
|
||||
bytes_per_elem = args["c_ptr"].element_size()
|
||||
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
|
||||
ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
|
||||
return ret
|
||||
@@ -191,7 +181,6 @@ def matmul_persistent(
|
||||
"num_warps": 8,
|
||||
},
|
||||
}
|
||||
# print(a.device, b.device, c.device)
|
||||
matmul_kernel_persistent[grid](
|
||||
a,
|
||||
b,
|
||||
@@ -420,7 +409,7 @@ def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
input: Input tensor
|
||||
dim: Dimension along which to compute log_softmax
|
||||
(only -1 or last dim supported)
|
||||
>> Stashed changes
|
||||
|
||||
Returns:
|
||||
Tensor with log_softmax applied along the specified dimension
|
||||
"""
|
||||
@@ -910,18 +899,11 @@ def linear_batch_invariant(input, weight, bias=None):
|
||||
|
||||
_batch_invariant_MODE = False
|
||||
_batch_invariant_LIB = None
|
||||
_original_torch_bmm = None
|
||||
_original_fp16_reduction_precision = None
|
||||
_original_bf16_reduction_precision = None
|
||||
_original_cublas_workspace_cfg = None
|
||||
_original_cublaslt_workspace_size = None
|
||||
_fp16_block_size_n = 256
|
||||
|
||||
|
||||
def enable_batch_invariant_mode():
|
||||
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
|
||||
global _original_fp16_reduction_precision, _original_bf16_reduction_precision
|
||||
global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
|
||||
global _batch_invariant_MODE, _batch_invariant_LIB
|
||||
global _fp16_block_size_n
|
||||
|
||||
if _batch_invariant_MODE:
|
||||
@@ -941,10 +923,6 @@ def enable_batch_invariant_mode():
|
||||
# Hopper (SM90) and Blackwell (SM100): the only source of batch
|
||||
# variance is split-k, which we disable via the cuBLAS workspace
|
||||
# config.
|
||||
_original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
|
||||
_original_cublaslt_workspace_size = os.environ.get(
|
||||
"CUBLASLT_WORKSPACE_SIZE", None
|
||||
)
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
|
||||
|
||||
@@ -966,16 +944,8 @@ def enable_batch_invariant_mode():
|
||||
_batch_invariant_LIB.impl(
|
||||
"aten::bmm", bmm_batch_invariant, "CUDA", allow_override=True
|
||||
)
|
||||
_original_torch_bmm = torch.bmm
|
||||
torch.bmm = bmm_batch_invariant
|
||||
|
||||
_original_bf16_reduction_precision = (
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
|
||||
)
|
||||
_original_fp16_reduction_precision = (
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
|
||||
)
|
||||
|
||||
reduced_precision_val = (
|
||||
(False, False) if is_torch_equal_or_newer("2.10.0") else False
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user