# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/matmul_ogs.py
# Triton is licensed under the MIT License.

# isort: off
# fmt: off
from dataclasses import dataclass
import itertools
import sys
import torch
import triton
from enum import Enum, auto
import math
# utilities
from triton_kernels import target_info
from triton_kernels.numerics import InFlexData, OutFlexData
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from triton_kernels.target_info import is_cuda
# details
from .matmul_ogs_details._matmul_ogs import _matmul_ogs
from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
from .matmul_ogs_details._reduce_grouped import _reduce_grouped
from .numerics_details.mxfp import MXFP_BLOCK_SIZE
from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
from .specialize import specialize
from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor


@dataclass(frozen=True)
class FnSpecs:
    name: str
    fn: "triton.runtime.jit.JITFunction"
    fn_arg_names: tuple[str]
    fn_arg_do_not_specialize: tuple[str] = tuple()

    @staticmethod
    def default():
        return FnSpecs("dflt", None, tuple())


@dataclass(frozen=True)
class FusedActivation:
    specs: FnSpecs = FnSpecs.default()
    fn_args: tuple[object] = tuple()
    reduction_n: int = 1


@dataclass(frozen=True)
class Epilogue:
    specs: FnSpecs = FnSpecs.default()
    fn_arg_values_matmul: tuple[object] = tuple()
    fn_arg_values_finalize: tuple[object] = tuple()
    effective_itemsize: float = None


class FnName(Enum):
    QUANTIZE_MXFP8 = auto()


EpilogueSpecs = FnSpecs  # TODO: remove this alias when callers are updated
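

# --- Editor's illustrative sketch; not part of the vendored Triton source. ---
# How a caller might wire a custom activation into the kernel: wrap an existing
# @triton.jit function in FnSpecs (a name, the JIT function, and the names of
# its extra scalar arguments), then pair it with the runtime values of those
# arguments in a FusedActivation. The function/argument names below are
# hypothetical; this helper is never called by this module.
def _example_build_fused_activation(jit_act_fn):
    # `jit_act_fn` is assumed to be a triton.runtime.jit.JITFunction that takes
    # one extra scalar argument called "alpha" in addition to the tile it transforms.
    specs = FnSpecs("example_act", jit_act_fn, ("alpha",))
    fused = FusedActivation(specs=specs, fn_args=(1.0,), reduction_n=1)
    epilogue = Epilogue()  # defaults: no epilogue function, no extra arguments
    return fused, epilogue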


_kernels = dict()


def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()):
    global _kernels
    key = (fused_activation.name, epilogue.name)
    if key in _kernels:
        return _kernels[key]
    spec_constants = {
        "ACTIVATION_FN": fused_activation.fn,
        "EPILOGUE_FN": epilogue.fn,
    }
    spec_tuples = {
        "activation_fn_args": fused_activation.fn_arg_names,
        "epilogue_fn_args": epilogue.fn_arg_names,
    }
    do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize
    import types

    module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}")
    sys.modules[module.__name__] = module
    module._matmul_ogs = specialize(_matmul_ogs, module, spec_constants, spec_tuples,
                                    do_not_specialize=do_not_specialize)
    module._p_matmul_ogs = specialize(_p_matmul_ogs, module, spec_constants, spec_tuples,
                                      do_not_specialize=do_not_specialize)
    module._reduce_grouped = specialize(_reduce_grouped, module, spec_constants, spec_tuples,
                                        do_not_specialize=do_not_specialize)
    _kernels[key] = module
    return module
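

# --- Editor's illustrative sketch; not part of the vendored Triton source. ---
# get_kernels caches one specialized kernel module per (activation name,
# epilogue name) pair, so repeated calls with the same FnSpecs names reuse the
# same specialized wrappers instead of re-specializing. Never called by this module.
def _example_kernel_cache_reuse():
    mod_a = get_kernels()   # default epilogue + activation -> key ("dflt", "dflt")
    mod_b = get_kernels()   # same key, served from the _kernels cache
    assert mod_a is mod_b
    return mod_a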


# -----------------------------------------------------------------------------
# Matrix Multiplication + Outer Gather/Scatter
# -----------------------------------------------------------------------------


def can_overflow_int32(tensor: torch.Tensor):
    max_int32 = (1 << 31) - 1
    offset = 0
    for i in range(tensor.ndim):
        offset += (tensor.shape[i] - 1) * tensor.stride(i)
    return offset > max_int32


def should_upcast_indices(*args):
    return any(tensor is not None and can_overflow_int32(tensor) for tensor in args)
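

# --- Editor's illustrative sketch; not part of the vendored Triton source. ---
# The overflow check above sums (shape[i] - 1) * stride(i) over all dimensions,
# i.e. the element offset of the last element. For a contiguous 2-D tensor that
# is rows * cols - 1, so a 50_000 x 50_000 tensor (~2.5e9 elements) needs 64-bit
# indexing while 30_000 x 30_000 (~9e8) does not. Uses meta tensors so no memory
# is allocated; never called by this module.
def _example_should_upcast_indices():
    small = torch.empty(30_000, 30_000, device="meta", dtype=torch.float16)
    large = torch.empty(50_000, 50_000, device="meta", dtype=torch.float16)
    assert not can_overflow_int32(small)
    assert can_overflow_int32(large)
    assert should_upcast_indices(small, None, large)  # any() over the operands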


# ---------------------
# Numerics
# ---------------------

# fmt: off


@dataclass(frozen=True)
class FlexCtx:
    lhs_data: InFlexData = InFlexData()
    rhs_data: InFlexData = InFlexData()
    out_data: OutFlexData = OutFlexData()


@dataclass
class PrecisionConfig:
    max_num_imprecise_acc: int = None
    allow_tf32: bool = True
    flex_ctx: FlexCtx = FlexCtx()
    acc_scale: int = 1.0
    flexpoint_saturate_inf: bool = False
    report_quantization_err_fn: callable = None
    act_scale: Tensor | None = None
    weight_scale: Tensor | None = None
    out_scale: Tensor | None = None
    out_dtype: torch.dtype = None
    enforce_bitwise_invariance: bool = False
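

# --- Editor's illustrative sketch; not part of the vendored Triton source. ---
# Typical callers override only a few PrecisionConfig fields and keep the rest
# at their defaults. The dtype choice below is an arbitrary example; never
# called by this module.
def _example_precision_config():
    return PrecisionConfig(
        allow_tf32=True,           # allow TF32 tensor cores for fp32 accumulation
        out_dtype=torch.bfloat16,  # final output dtype (defaults to x.dtype when None)
    )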


# TODO: merge in opt_flags
def get_swap_xw(precision_config, opt_flags):
    if target_info.cuda_capability_geq(10, 0):
        return precision_config.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent
    return False


# ---------------------
# Allocation
# ---------------------


@dataclass
class MatmulAllocation:
    device: str
    output: tuple[tuple[int], torch.dtype]
    scratchpads: dict[str, tuple]


def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags):
    # ---- output ------
    N = w.shape[-1]
    # by default, M is the number of rows in the activations
    M = x.shape[-2]
    # if the activations are gathered, then M is the number of gather indices
    if gather_indx is not None:
        M = gather_indx.src_indx.shape[0]
    # final output
    if routing_data.n_expts_act == 1 or scatter_indx is None:
        y_rows = M
    else:
        Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act  # compressed number of rows
        y_rows = Mc
    batch_dim = x.shape[0] if x.ndim == 3 else 1
    out_shape = (batch_dim, y_rows, N // fused_activation.reduction_n)
    out_dtype = precision_config.out_dtype or x.dtype
    output = (out_shape, out_dtype)
    # ---- scratchpad -----#
    scratchpad = dict()
    if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
        scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype
        scratchpad["matmul"] = ((opt_flags.split_k, 1, M, N), scratch_out_dtype)
    if "matmul" in scratchpad and precision_config.out_scale is not None:
        scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE)), torch.uint8)
    return MatmulAllocation(x.device, output, scratchpad)
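

# --- Editor's illustrative sketch; not part of the vendored Triton source. ---
# The output-shape arithmetic above, written out for a routed (MoE) case: with
# top-k routing (n_expts_act = k), the gathered activations have M = k * n_tokens
# rows and the scatter index has k * n_tokens entries, so the final output is
# compressed back to Mc = n_tokens rows. Never called by this module.
def _example_output_rows(n_tokens: int = 128, n_expts_act: int = 4):
    gathered_rows = n_tokens * n_expts_act        # M after gather
    scatter_len = n_tokens * n_expts_act          # scatter_indx.src_indx.shape[0]
    compressed_rows = scatter_len // n_expts_act  # Mc, rows of the final output
    assert compressed_rows == n_tokens
    return gathered_rows, compressed_rows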


def apply_allocation(allocation: MatmulAllocation, output):
    ret = dict()
    if output is None:
        output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
    else:
        assert output.shape == allocation.output[0]
    ret["output"] = output[None, :, :]
    ret["scratchpad"] = {
        k: torch.empty(v[0], device=allocation.device, dtype=v[1])
        for k, v in allocation.scratchpads.items()
    }
    return ret


# -----------------------------------------------------------------------------
# Canonicalize
# -----------------------------------------------------------------------------
# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used
# we can canonicalize storages to make the implementation more uniform


def _canonicalize_storage(storage, out_ndim, flex_data):
    assert out_ndim >= storage.data.ndim
    # Need to use as_strided instead of view because a tensor with
    # shape[-2] == 1 can have ambiguity related to col-wise strides. For example,
    # > t = torch.randn(2, 5, 1).mT
    # > t_view = t.view(t.shape)
    # > t.stride(), t_view.stride()
    # ((5, 1, 1), (5, 5, 1))
    # Our check that t_view is col-wise fails since t_view.stride(-2) != 1
    # This case is covered by (m, n, k) == (1000, 700, 2) in test_matmul.py
    new_storage_shape = [1] * (out_ndim - storage.data.ndim) + list(storage.data.shape)
    new_storage_view = storage.data.view(new_storage_shape)
    new_storage_stride = [new_storage_view.stride(0)] * (out_ndim - storage.data.ndim) + list(storage.data.stride())
    new_storage_data = storage.data.as_strided(new_storage_shape, new_storage_stride)
    if flex_data is not None:
        new_storage_data = flex_data.reinterpret(new_storage_data)
    return Storage(new_storage_data, storage.layout)
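

# --- Editor's illustrative sketch; not part of the vendored Triton source. ---
# The comment above in concrete form: for a tensor with a size-1 dimension,
# view() may recompute strides as if the tensor were contiguous, whereas
# as_strided() lets us keep the original (column-wise) strides. Never called
# by this module.
def _example_view_vs_as_strided():
    t = torch.randn(2, 5, 1).mT                    # shape (2, 1, 5), stride (5, 1, 1)
    t_view = t.view(t.shape)                       # stride becomes (5, 5, 1)
    t_strided = t.as_strided(t.shape, t.stride())  # stride stays (5, 1, 1)
    assert t_view.stride() != t.stride()
    assert t_strided.stride() == t.stride()
    return t_view.stride(), t_strided.stride()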


#
def reduce_grouped(x: torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_mx_scale: torch.Tensor,
                   fused_activation, epilogue,
                   x_flex: InFlexData | None = None,
                   out_flex: OutFlexData | None = None, x_mx_scale: torch.Tensor | None = None,
                   out_dtype: bool = None, flexpoint_saturate_inf: bool = False):
    """
    In-place grouped row reduction.

    Arguments
    - x: Tensor[AnyFloat] of shape [(num_groups * K), N]
    - indx: Tensor[Int] of shape [num_groups, K]

    Description
    For each group g in [0, num_groups), this routine sums the K rows of `x`
    specified by `indx[g, :]` and overwrites the row corresponding to the first
    valid (non-negative) index with the per-group sum. Accumulation is performed
    in float32 for numerical stability, and the result is written back in the
    dtype of `x`.

    Behavior and edge cases
    - Invalid (-1) entries are skipped during accumulation and do not generate
      memory traffic. If a group has no valid entries, nothing is written for
      that group.
    - Reduction is performed tile-by-tile along the N dimension within a single
      kernel launch (persistent along N) to minimize launch overhead.

    Performance notes
    - Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x),
      plus index reads. With no invalid entries, this becomes (K + 1) reads/writes
      of length N per group.

    Returns
    - The input tensor `x` (modified in place).
    """
    if indx is None and x.shape[0] == 1:
        return x.squeeze(0), None
    if indx is not None:
        num_groups = indx.shape[0]
    else:
        num_groups = x.shape[-2]
    if x_flex is None:
        x_flex = InFlexData()
    if out_flex is None:
        out_flex = OutFlexData()
    K = 1 if indx is None else indx.shape[1]
    out_dtype = x.dtype if out_dtype is None else out_dtype
    assert x.shape[-1] % fused_activation.reduction_n == 0
    BLOCK_N = 512
    # Resolve scalar flex scales (may be None)
    x_expected_scale = None if x_flex is None else x_flex.scale
    out_expected_scale = None if out_flex is None else out_flex.expected_scale
    out_actual_scale = None if out_flex is None else out_flex.actual_scale
    out_checksum_scale = None if out_flex is None else out_flex.checksum_scale
    # Resolve MXFP output scale row stride
    stride_mxb = 0 if x_mx_scale is None else x_mx_scale.stride(0)
    stride_mxs = 0 if x_mx_scale is None else x_mx_scale.stride(1)
    stride_omxs = 0 if out_mx_scale is None else out_mx_scale.stride(0)
    kernels = get_kernels(epilogue.specs, fused_activation.specs)
    kernels._reduce_grouped[(num_groups, )](
        x_flex.reinterpret(x), x.stride(0), x.stride(2), x.stride(3),  #
        x_expected_scale,  # scalar input scale
        out_flex.reinterpret(out), out.stride(1), out.stride(2),  #
        out_expected_scale, out_actual_scale, out_checksum_scale, indx,  #
        x.shape[0], x.shape[-1],  #
        x_mx_scale, stride_mxb, stride_mxs,  #
        out_mx_scale, stride_omxs,  #
        *fused_activation.fn_args, fused_activation.reduction_n,
        *epilogue.fn_arg_values_finalize,
        HAS_IN_MX_SCALE=x_mx_scale is not None, HAS_OUT_MX_SCALE=out_mx_scale is not None,
        FLEXPOINT_SATURATE_INF=flexpoint_saturate_inf,  #
        BLOCK_N=BLOCK_N, K=K,  #
        num_warps=1,  #
    )
    return out, out_mx_scale
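

# --- Editor's illustrative sketch; not part of the vendored Triton source. ---
# Reference semantics of the grouped reduction above, in plain PyTorch: for each
# group, sum the rows selected by the non-negative entries of `indx` in float32.
# This mirrors the per-group accumulation the Triton kernel performs, without
# fused activation, flex scales, or MXFP handling. Never called by this module.
def _example_reduce_grouped_reference(x2d: torch.Tensor, indx: torch.Tensor) -> torch.Tensor:
    num_groups, _k = indx.shape
    out = torch.zeros((num_groups, x2d.shape[-1]), dtype=torch.float32, device=x2d.device)
    for g in range(num_groups):
        rows = indx[g][indx[g] >= 0]  # skip invalid (-1) entries
        if rows.numel() > 0:
            out[g] = x2d[rows].float().sum(dim=0)
    return out.to(x2d.dtype)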


# -----------------------------------------------------------------------------
# Triton Implementation
# -----------------------------------------------------------------------------


def matmul_ogs_set_idle_sms(num_idle_sms):
    """
    persistent kernels will leave `num_idle_sms` idle
    """
    update_opt_flags_constraints({"idle_sms": num_idle_sms})
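

# --- Editor's illustrative sketch; not part of the vendored Triton source. ---
# A caller that wants to reserve a few SMs for an overlapping kernel (e.g. a
# communication kernel) can set this global constraint before launching matmuls.
# The value 8 below is arbitrary; never called by this module.
def _example_reserve_sms_for_overlap():
    matmul_ogs_set_idle_sms(8)  # persistent matmul kernels will leave 8 SMs idle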


def matmul_ogs(x, w, bias,
               routing_data: RoutingData | None = None,
               gather_indx: GatherIndx | None = None,
               scatter_indx: ScatterIndx | None = None,
               precision_config: PrecisionConfig | None = None,
               betas: torch.Tensor | None = None,
               gammas: torch.Tensor | None = None,
               out_alpha: float | None = None,
               y: torch.Tensor | None = None,
               fused_activation: FusedActivation | None = None,
               epilogue: Epilogue | None = None,
               ):
    """
    Y[:, :] = 0.
    for e in num_experts:
        Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])
    """
    is_input_batched = x.ndim == 3
    if is_input_batched:
        assert gather_indx is None, "gather not supported in batched mode"
        assert scatter_indx is None, "scatter not supported in batched mode"
        assert routing_data is None, "routing not supported in batched mode"
        assert w.ndim == 3 and w.shape[0] == x.shape[0]
    # canonicalize inputs
    if precision_config is None:
        precision_config = PrecisionConfig()
    if fused_activation is None:
        fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
    if epilogue is None:
        epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
    if routing_data is None:
        routing_data = RoutingData(None, None, max(1, w.shape[0]), 1)
    # unpack scales
    w_scale = precision_config.weight_scale
    w_has_mx = w_scale is not None
    is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
    if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
    if not isinstance(w, Tensor):
        # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
        dtype = FP4 if w.dtype == torch.uint8 else w.dtype
        w = wrap_torch_tensor(w, dtype=dtype)
    if w_scale is not None and not isinstance(w_scale, Tensor):
        w_scale = Tensor(w_scale)
    if w_scale is not None:
        w_scale.storage.data = w_scale.data.view(torch.uint8)
        w_scale.dtype = torch.uint8
    x_scale = precision_config.act_scale
    x_has_mx = x_scale is not None
    if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp"
    if x_scale is not None and not isinstance(x_scale, Tensor):
        x_scale = Tensor(x_scale)
    if not isinstance(x, Tensor):
        x = Tensor(x, dtype=x.dtype)
    # determine shapes
    has_gather = gather_indx is not None
    has_scatter = scatter_indx is not None
    is_ragged = routing_data.expt_hist is not None
    M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
    batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
    K, N = w.shape[-2:]
    assert K == x.shape[-1]
    if x.ndim == 3 and w.ndim == 3:
        assert x.shape[0] == w.shape[0]
    # compute optimization flags
    out_dtype = precision_config.out_dtype or x.dtype
    can_use_tma = x.numel() > 0 and x.storage.is_tma_compliant() and \
                  w.numel() > 0 and w.storage.is_tma_compliant() and \
                  (w_scale is None or w_scale.storage.is_tma_compliant())
    # hopper w/ mxfp4 doesn't support TMA
    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
    can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
    opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
                               M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
                               )
    if not can_use_fused_scatter and opt_flags.fused_scatter:
        raise InapplicableConstraint("Fused scatter is not supported")
    if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp():
        raise NotImplementedError("Must use non-persistent kernel for simulated MXFP")
    if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp():
        raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP")
    # fused activation
    matmul_fused_activation = fused_activation
    reduce_fused_activation = FusedActivation()
    if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
        matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation
    # allocate output/scratchpad memory
    allocation = init_allocation(x, w, precision_config, fused_activation,
                                 routing_data, gather_indx, scatter_indx, opt_flags)
    memory = apply_allocation(allocation, y)
    # early exit
    if batch_size * M * N == 0:
        ret = memory["output"].squeeze(0)
        if not is_input_batched:
            ret = ret.squeeze(0)
        return ret
    # TMA descriptors require a global memory allocation
    if opt_flags.is_persistent:
        triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
    # Intermediate tensors and postprocess kernels for each situation
    has_scratchpad = "matmul" in memory["scratchpad"]
    # Canonical output tensor (matmul scratchpad if present, otherwise final output tensor)
    out_matmul = memory["scratchpad"].get("matmul", memory["output"])
    out_matmul_flex = OutFlexData() if out_matmul.dtype == torch.float32 else precision_config.flex_ctx.out_data
    # Unified mx-scale pointer; when scratchpad exists, prefer its mx buffer
    out_matmul_scale = precision_config.out_scale
    if out_matmul_scale is not None:
        out_matmul_scale = out_matmul_scale.data.view(torch.uint8)
    if has_scratchpad and "mx_out_scale" in memory["scratchpad"]:
        out_matmul_scale = memory["scratchpad"]["mx_out_scale"]
    out_matmul_has_mx = out_matmul_scale is not None and out_matmul.element_size() == 1
    # matrix multiplication
    flex = precision_config.flex_ctx
    bias_stride = None if bias is None else bias.stride(0)
    num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
    # moe metadata
    expt_data = routing_data.expt_data
    block_m = opt_flags.block_m
    expt_hist = None if expt_data is None else expt_data.hist
    expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
    expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
    expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
    # spmd grid
    grid_m = triton.cdiv(M, opt_flags.block_m)
    if expt_block_pid_map is not None:
        grid_m = routing_data.n_blocks(M, opt_flags.block_m)
    grid_n = triton.cdiv(N, opt_flags.block_n)
    max_grid = batch_size * grid_m * grid_n * opt_flags.split_k
    grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
    # canonicalize storage
    has_gather_tma = has_gather and target_info.has_tma_gather()
    has_scatter_tma = opt_flags.fused_scatter and target_info.has_tma_gather()
    y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if opt_flags.fused_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:]))
    x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
    w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
    y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
    # create tma descriptor for x
    x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
    x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k]
    x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense"
    x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data
    # create tma descriptor for y
    y_has_tma = opt_flags.is_persistent and (has_scatter_tma or not opt_flags.fused_scatter)
    block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.reduction_n
    y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n]
    y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense"
    y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data
    # create tma descriptor for w
    w_has_tma = opt_flags.is_persistent
    w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
    # create tma descriptor for w_scale
    w_scale_tensor_or_tma = w_scale
    w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
    w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale
    # canonicalize strides
    x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
    x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
    x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides
    w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None)
    w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides
    out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None)
    out_matmul_scale_strides = (0, ) * (3 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
    # launch kernel
    kernels = get_kernels(epilogue.specs, matmul_fused_activation.specs)
    # When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed
    # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose
    # is True (the fast code path), stride(-2) == 1 takes precedence, e.g., vs.
    # w_transpose = w_storage.data.stride()[-1] != 1
    w_transpose = w_storage.data.stride()[-2] == 1
    (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
        y_tensor_or_tma, y_storage.data, *out_matmul.stride(),
        *((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
        *out_matmul_scale_strides[-3:],
        x_tensor_or_tma, x_storage.data, *x_strides,
        flex.lhs_data.scale,
        None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
        w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose,
        flex.rhs_data.scale,
        w_scale_tensor_or_tma, *w_scale_strides,
        bias, bias_stride,
        x.shape[-2],
        x.shape[-2] if routing_data.expt_hist is None else None,
        N, K,
        betas, gammas,
        None if gather_indx is None else gather_indx.src_indx,
        None if scatter_indx is None else scatter_indx.src_indx,
        num_indx,
        None if not opt_flags.fused_scatter else scatter_indx.dst_indx,
        None if not opt_flags.fused_scatter else scatter_indx.dst_indx.shape[0],
        expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map,
        batch_size, grid_m, grid_n,
        out_alpha,
        *matmul_fused_activation.fn_args, matmul_fused_activation.reduction_n,
        *epilogue.fn_arg_values_matmul,
        routing_data.n_expts_tot, routing_data.n_expts_act,
        precision_config.max_num_imprecise_acc,
        precision_config.allow_tf32,
        precision_config.flexpoint_saturate_inf,
        flex.rhs_data.is_per_batch,
        opt_flags.block_m,
        opt_flags.block_n,
        opt_flags.block_k,
        opt_flags.group_m,
        XCD_SWIZZLE=opt_flags.xcd_swizzle,
        SWIZZLE_MX_VALUE=w.storage.layout.name,
        SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name,
        EPILOGUE_SUBTILE=opt_flags.epilogue_subtile,
        SPLIT_K=opt_flags.split_k,
        EVEN_K=K % opt_flags.block_k == 0,
        W_CACHE_MODIFIER=opt_flags.w_cache_modifier,
        TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt,
        num_warps=opt_flags.num_warps,
        num_stages=opt_flags.num_stages,
        arch=opt_flags.arch,
        UPCAST_INDICES=should_upcast_indices(x, w, out_matmul),
        X_TMA_MODE=x_tma_mode,
        Y_TMA_MODE=y_tma_mode,
        SWAP_XW=get_swap_xw(precision_config, opt_flags),
        IS_EPILOGUE_QUANT_MXFP8=epilogue.specs.name == FnName.QUANTIZE_MXFP8.name,
        NUM_SMS=grid if opt_flags.is_persistent else 0,
        **opt_flags.target_kernel_kwargs)
    # Build grouped reduction inputs in a uniform way
    group_indx = None if scatter_indx is None or opt_flags.fused_scatter else scatter_indx.src_indx.view(-1, routing_data.n_expts_act)
    out_final, out_final_mx_scale = reduce_grouped(
        out_matmul,
        group_indx,
        memory["output"].squeeze(0),
        precision_config.out_scale,
        reduce_fused_activation,
        epilogue,
        x_flex=InFlexData(dtype=out_matmul_flex.dtype, scale=out_matmul_flex.expected_scale),
        out_flex=precision_config.flex_ctx.out_data,
        x_mx_scale=out_matmul_scale.squeeze(1) if out_matmul_has_mx else None,
        out_dtype=memory["output"].dtype,
        flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf,
    )
    if not is_input_batched:
        out_final = out_final.squeeze(0)
    if out_final_mx_scale is not None:
        precision_config.out_scale = out_final_mx_scale
    return out_final
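

# --- Editor's illustrative sketch; not part of the vendored Triton source. ---
# Minimal dense call (no routing, gather, or scatter), following the docstring
# above: with a single expert the computation reduces to Y = X @ W. Assumes a
# CUDA device; shapes and dtypes are arbitrary examples. Never called by this module.
def _example_dense_matmul_ogs():
    x = torch.randn(128, 256, device="cuda", dtype=torch.float16)     # (M, K)
    w = torch.randn(1, 256, 512, device="cuda", dtype=torch.float16)  # single "expert": (1, K, N)
    y = matmul_ogs(x, w, bias=None)                                   # (M, N)
    return y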


# -----------------------------------------------------------------------------
# Reference Implementation
# -----------------------------------------------------------------------------


def matmul_ogs_torch(x, w, bias,
                     routing_data: RoutingData = None,
                     gather_indx: GatherIndx = None,
                     scatter_indx: ScatterIndx = None,
                     precision_config: PrecisionConfig = None,
                     betas = None,
                     gammas = None,
                     round_x = None, round_y = None,
                     ):
    is_input_batched = x.ndim == 3
    assert x.dtype.itemsize > 1
    assert w.dtype.itemsize > 1
    if is_input_batched:
        assert gather_indx is None, "gather not supported in batched mode"
        assert scatter_indx is None, "scatter not supported in batched mode"
        assert routing_data is None, "routing not supported in batched mode"
        assert w.ndim == 3 and w.shape[0] == x.shape[0]
    if round_x is None:
        round_x = lambda x, idx: x
    if round_y is None:
        round_y = lambda x: x
    if bias is not None and bias.ndim == 1:
        bias = bias.view(1, *bias.shape)
    if w.ndim == 2:
        w = w.view(1, *w.shape)
    if x.ndim == 2:
        x = x.view(1, *x.shape)
    if routing_data is None:
        routing_data = RoutingData(None, None, w.shape[0], 1)
    n_expts_act = routing_data.n_expts_act
    # memory offsets
    if routing_data.n_expts_tot > 1 and not is_input_batched:
        sizes = routing_data.expt_hist
        off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32)
        off[1:] = torch.cumsum(sizes, 0)
        offs = list(itertools.pairwise(off))
    else:
        offs = [[0, x.shape[1]] for _ in range(w.shape[0])]
    # compute
    n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0]
    y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype)
    for i, (lo, hi) in enumerate(offs):
        if gather_indx is None:
            idx = torch.arange(lo, hi, device=x.device)
        else:
            idx = gather_indx.src_indx[lo:hi] // n_expts_act
        batch = i if is_input_batched else 0
        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
                           w[i].float())
        if bias is not None:
            out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
        if gammas is not None:
            out *= gammas[lo:hi, None]
        y[batch, lo:hi, :] = round_y(out)
    if not is_input_batched:
        y = y.view(y.shape[1], y.shape[2])
    if scatter_indx is None:
        return y
    # accumulate output from all experts
    n_rows = y.shape[0] // n_expts_act
    out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device)
    for i, (lo, hi) in enumerate(offs):
        dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act
        msk = dst_idx != -1
        out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float()
    return out
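

# --- Editor's illustrative sketch; not part of the vendored Triton source. ---
# The reference implementation above is intended as a numerical oracle for the
# Triton path: for a simple dense case both should agree up to accumulation
# error. Assumes a CUDA device; tolerances are arbitrary examples. Never called
# by this module.
def _example_compare_against_reference():
    x = torch.randn(64, 128, device="cuda", dtype=torch.float16)
    w = torch.randn(1, 128, 32, device="cuda", dtype=torch.float16)  # single "expert"
    y_triton = matmul_ogs(x, w, bias=None)
    y_ref = matmul_ogs_torch(x, w, bias=None)
    torch.testing.assert_close(y_triton.float(), y_ref.float(), rtol=1e-2, atol=1e-2)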