[https://nvbugs/5761391][fix] Include triton-kernels as a packaged dependency (#10471)

Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
Anish Shanbhag authored on 2026-01-28 19:56:32 -08:00; committed by GitHub
parent e20f9a9c72
commit 24ac86c485
74 changed files with 6312 additions and 248 deletions

.github/CODEOWNERS vendored

@ -217,6 +217,12 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers
## Any changes to versions, additions, or removals of third-party libraries
/3rdparty/** @NVIDIA/trt-llm-oss-compliance
### Vendored Third-Party Code (triton-kernels)
## This is a temporary vendored copy of triton-kernels from the Triton project (MIT License).
## Do not accept contributions to this directory - it should only be updated via scripts/vendor_triton_kernels.py
## This can be removed if and when triton-kernels is published as a separate wheel.
/triton_kernels/** @NVIDIA/trt-llm-oss-compliance
### Docker & Installation Scripts
## These scripts install and pin dependency versions
/docker/common/** @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs @NVIDIA/trt-llm-oss-compliance


@ -1366,6 +1366,9 @@ common-files: &common_files |
triton_backend/tools/whisper/client.py |
)$
# Global exclude pattern for vendored third-party code
exclude: '^triton_kernels/'
default_install_hook_types: [pre-commit, commit-msg]
repos:
- repo: https://github.com/pycqa/isort


@ -62379,7 +62379,7 @@ Copyright 2018- The Hugging Face team. All rights reserved.
- `Homepage`: https://github.com/huggingface/transformers
## triton (3.5.0)
## triton (3.5.1)
### Licenses
License: `MIT License`
@ -62417,6 +62417,40 @@ License: `MIT License`
- `Homepage`: https://github.com/triton-lang/triton/
## triton-kernels (3.5.1)
### Licenses
License: `MIT License`
- `LICENSE` (from triton repository root):
```
Copyright 2018-2020 Philippe Tillet
Copyright 2020-2022 OpenAI
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files
(the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of the Software,
and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
```
### URLs
- `Source`: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels
## tritonclient (2.63.0)
### Licenses


@ -107,33 +107,10 @@ Once again, the function call works successfully, this time using a different fu
## Using OpenAI Triton Kernels for MoE
OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels; enable them with the steps below:
1. **Build and install Triton** (tested with the commit below):
OpenAI ships a set of Triton kernels optimized for its MoE models.
```bash
git clone https://github.com/triton-lang/triton.git
cd triton
# Specific commit verified with TensorRT-LLM
git checkout f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f
pip install -r python/requirements.txt # build-time dependencies
pip install wheel build
python3 setup.py bdist_wheel
pip install ./dist/*.whl
```
2. **Expose the Triton kernels to TensorRT-LLM**
The kernels are not packaged in the wheel, so set the environment variable `TRITON_ROOT` to your Triton clone:
```bash
export TRITON_ROOT=/local/user/triton
# TensorRT-LLM expects the kernels at:
# $TRITON_ROOT/python/triton_kernels
```
3. **Select Triton as the MoE backend**
**trtllm-serve** (or other similar commands) — add this snippet to the YAML file passed via `--config`:
To use the Triton MoE backend with **trtllm-serve** (or other similar commands), add this snippet to the YAML file passed via `--config`:
```yaml
moe_config:

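The same backend selection can also be made programmatically through the LLM API; a minimal sketch, assuming a locally available GPT-OSS checkpoint (the path below is illustrative) and the `LLM`/`MoeConfig` classes used in the tests later in this commit:

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import MoeConfig

# Select the Triton MoE backend, mirroring the `moe_config` YAML snippet above.
llm = LLM("/path/to/gpt-oss-checkpoint",  # illustrative checkpoint path
          moe_config=MoeConfig(backend="TRITON"))
with llm:
    outputs = llm.generate(["How are you?"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)
```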

@ -10,13 +10,15 @@ build-backend = "setuptools.build_meta"
####################################################################################################
[tool.isort]
line_length = 80
known_first_party = ["tensorrt_llm"]
known_third_party = ["triton_kernels"]
[tool.yapf]
based_on_style = "pep8"
column_limit = 80
[tool.codespell]
skip = ".git,3rdparty,tests/integration/test_input_files**,**.jsonl,**.json"
skip = ".git,3rdparty,triton_kernels,tests/integration/test_input_files**,**.jsonl,**.json"
exclude-file = "examples/models/core/whisper/tokenizer.py"
ignore-words-list = "rouge,inout,atleast,strat,nd,subtile,thrid,improbe,NotIn,te,iteract,anythin,tru,Tracin,vEw,dOut"
@ -42,6 +44,7 @@ fix = true
# orders of magnitude faster, so we should move to deprecate `yapf`.
exclude = [
"**3rdparty/**",
"triton_kernels/**",
".devcontainer/make_env.py",
".github/scripts/label_community_user.py",
".github/scripts/pr_checklist_check.py",
@ -1458,6 +1461,7 @@ convention = "google"
[tool.ruff.lint.isort]
known-first-party = ["tensorrt_llm"]
known-third-party = ["triton_kernels"]
split-on-trailing-comma = false
@ -1490,7 +1494,7 @@ disallow_untyped_calls = false
disallow_incomplete_defs = false
disallow_untyped_defs = false
warn_return_any = false
exclude = []
exclude = ["triton_kernels"]
[[tool.mypy.overrides]]


@ -66,7 +66,7 @@ ninja
etcd3 @ git+https://github.com/kragniz/python-etcd3.git@e58a899579ba416449c4e225b61f039457c8072a
blake3
soundfile
triton==3.5.1
triton==3.5.1 # NOTE: if you update this, you must also run scripts/vendor_triton_kernels.py to vendor the new version of triton_kernels
tiktoken
blobfile
openai-harmony==0.0.4
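For illustration, the update workflow that note refers to, based on the vendor script added in this commit (the tag below is hypothetical and must match the new `triton` pin):

```bash
# 1. Bump the pin above, e.g. triton==3.6.0
# 2. Re-vendor the matching triton_kernels sources from the same tag
python scripts/vendor_triton_kernels.py --tag v3.6.0
# 3. Commit the regenerated triton_kernels/ directory together with the pin change
```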

scripts/vendor_triton_kernels.py Executable file

@ -0,0 +1,200 @@
#!/usr/bin/env python3
"""Script to vendor triton-kernels into TensorRT-LLM.
This script:
1. Clones the Triton repo at a specific tag to a temp directory
2. Copies the triton_kernels module to the repo root as a top-level package
3. Adds attribution headers to all Python files
4. Copies the LICENSE file from Triton
5. Creates a VERSION file to track the vendored version
6. Creates a README.md with clear copyright attribution
To update to a new version:
python scripts/vendor_triton_kernels.py --tag v3.6.0
"""
import argparse
import shutil
import subprocess
import tempfile
from pathlib import Path
REPO_ROOT = Path(__file__).parent.parent.resolve()
TRITON_REPO_URL = "https://github.com/triton-lang/triton.git"
TRITON_KERNELS_MODULE_PATH = "python/triton_kernels/triton_kernels"
DEST_PATH = REPO_ROOT / "triton_kernels"
VENDORED_NOTICE = "# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY."
ATTRIBUTION_HEADER = f"""\
{VENDORED_NOTICE}
# Source: https://github.com/triton-lang/triton/tree/{{tag}}/{{original_file}}
# Triton is licensed under the MIT License.
"""
def clone_triton(tag: str, dest_dir: str) -> tuple[Path, Path]:
"""Clone the Triton repo at the specified tag. Returns (module_path, repo_root)."""
print(f"Cloning Triton repo at tag {tag}...")
subprocess.run(
["git", "clone", "--depth", "1", "--branch", tag, TRITON_REPO_URL, dest_dir],
check=True,
capture_output=True,
text=True,
)
repo_root = Path(dest_dir)
triton_kernels_module_path = repo_root / TRITON_KERNELS_MODULE_PATH
if not triton_kernels_module_path.exists():
raise RuntimeError(f"triton_kernels module not found at {triton_kernels_module_path}")
return triton_kernels_module_path, repo_root
def add_attribution_header(file_path: Path, tag: str, original_rel_path: str) -> None:
content = file_path.read_text()
# Handle shebang and encoding declarations
lines = content.split("\n")
insert_pos = 0
preserved_lines = []
for i, line in enumerate(lines):
if line.startswith("#!") or line.startswith("# -*-") or line.startswith("# coding"):
preserved_lines.append(line)
insert_pos = i + 1
else:
break
header = ATTRIBUTION_HEADER.format(tag=tag, original_file=original_rel_path)
new_content = "\n".join(preserved_lines)
if preserved_lines:
new_content += "\n"
new_content += header
# Add blank line between header and content if file has content
remaining_content = "\n".join(lines[insert_pos:])
if remaining_content.strip():
new_content += "\n"
new_content += remaining_content
file_path.write_text(new_content)
def copy_triton_kernels(src_path: Path, dest_path: Path, tag: str) -> list[str]:
"""Copy triton_kernels module to destination and add attribution headers."""
print(f"Copying triton_kernels to {dest_path}...")
if dest_path.exists():
print(f" Removing existing {dest_path}")
shutil.rmtree(dest_path)
shutil.copytree(src_path, dest_path)
# Add attribution headers to all existing Python files
python_files = []
for py_file in dest_path.rglob("*.py"):
rel_path = py_file.relative_to(dest_path)
original_rel_path = f"{TRITON_KERNELS_MODULE_PATH}/{rel_path}"
python_files.append(str(rel_path))
add_attribution_header(py_file, tag, original_rel_path)
# Create __init__.py files in subdirs that don't have them.
# Triton's upstream code relies on implicit namespace packages (PEP 420), but
# setuptools' find_packages() requires __init__.py to discover subpackages.
for subdir in dest_path.rglob("*"):
if subdir.is_dir():
init_file = subdir / "__init__.py"
if not init_file.exists():
print(f" Creating {init_file.relative_to(dest_path)}")
init_file.write_text(f"{VENDORED_NOTICE}\n")
print(f" Copied triton_kernels module to {dest_path}")
return python_files
def copy_license(triton_repo_root: Path, dest_path: Path) -> None:
"""Copy the Triton LICENSE file."""
print("Copying LICENSE file...")
license_src = triton_repo_root / "LICENSE"
license_dest = dest_path / "LICENSE"
shutil.copy2(license_src, license_dest)
print(f" Copied LICENSE to {license_dest}")
def create_version_file(dest_path: Path, tag: str) -> None:
"""Create a VERSION file to track which version was vendored."""
print("Creating VERSION file...")
version_file = dest_path / "VERSION"
version_content = f"""{tag}
# This file tracks the version of triton-kernels that was vendored.
# To update, run: python scripts/vendor_triton_kernels.py --tag <new-tag>
"""
version_file.write_text(version_content)
print(f" Created {version_file}")
def create_readme(dest_path: Path, tag: str) -> None:
"""Create a README.md with clear copyright attribution."""
print("Creating README.md...")
readme_file = dest_path / "README.md"
readme_content = f"""# Vendored triton_kernels
This directory contains code vendored from the [Triton](https://github.com/triton-lang/triton) project.
| | |
|---|---|
| **Copyright** | The Triton Authors |
| **License** | MIT (see [LICENSE](LICENSE) file in this directory) |
| **Source** | https://github.com/triton-lang/triton/tree/{tag}/python/triton_kernels/triton_kernels |
| **Version** | `{tag}` |
## Attribution
This code is the work of the Triton authors and is included here under the MIT License.
Each Python file includes an attribution header indicating its origin.
## Do Not Edit
This code is vendored verbatim and should not be modified directly.
To update to a newer version, run:
```bash
python scripts/vendor_triton_kernels.py --tag <new-tag>
```
"""
readme_file.write_text(readme_content)
print(f" Created {readme_file}")
def main():
parser = argparse.ArgumentParser(description="Vendor triton-kernels into TensorRT-LLM")
parser.add_argument(
"--tag",
required=True,
help="Triton git tag to vendor from. See the list of tags at https://github.com/triton-lang/triton/tags",
)
args = parser.parse_args()
print(f"Vendoring triton-kernels from Triton {args.tag}")
print(f"Destination: {DEST_PATH}")
with tempfile.TemporaryDirectory() as tmp_dir:
triton_kernels_src, triton_repo_root = clone_triton(args.tag, tmp_dir)
copy_triton_kernels(triton_kernels_src, DEST_PATH, args.tag)
copy_license(triton_repo_root, DEST_PATH)
create_version_file(DEST_PATH, args.tag)
create_readme(DEST_PATH, args.tag)
print("SUCCESS: triton-kernels has been vendored.")
if __name__ == "__main__":
main()


@ -365,6 +365,11 @@ else:
# Ensure rawref is included
package_data.append('runtime/kv_cache_manager_v2/rawref/*.so')
# Add vendored triton_kernels as an explicit top-level package.
# This is vendored from the Triton project and kept at repo root so its
# internal absolute imports (e.g., "from triton_kernels.foo import bar") work.
packages += find_packages(include=["triton_kernels", "triton_kernels.*"])
# https://setuptools.pypa.io/en/latest/references/keywords.html
setup(
name='tensorrt_llm',
@ -392,6 +397,7 @@ setup(
keywords="nvidia tensorrt deeplearning inference",
package_data={
'tensorrt_llm': package_data,
'triton_kernels': ['LICENSE', 'VERSION', 'README.md'],
},
license_files=get_license(),
entry_points={
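To sanity-check the packaging change, the built wheel can be inspected for the vendored package and its metadata files; a sketch, assuming a wheel has already been produced under `dist/`:

```bash
unzip -l dist/tensorrt_llm-*.whl | grep -E 'triton_kernels/(__init__\.py|LICENSE|VERSION|README\.md)'
```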


@ -61,6 +61,43 @@ def _preload_python_lib():
_preload_python_lib()
import sys
from pathlib import Path
def _setup_vendored_triton_kernels():
"""Ensure our vendored triton_kernels takes precedence over any existing installation.
Some environments bundle triton_kernels, which can conflict with our vendored version. This function:
1. Clears any pre-loaded triton_kernels from sys.modules
2. Temporarily adds our package root to sys.path
3. Imports triton_kernels (caching our version in sys.modules)
4. Removes the package root from sys.path
"""
# Clear any pre-loaded triton_kernels from cache
for mod in list(sys.modules.keys()):
if mod == "triton_kernels" or mod.startswith("triton_kernels."):
del sys.modules[mod]
# Temporarily add our package root to sys.path
root = Path(__file__).parent.parent
vendored = root / "triton_kernels"
if not vendored.exists():
raise RuntimeError(
f"Vendored triton_kernels module not found at {vendored}")
should_add_to_path = str(root) not in sys.path
if should_add_to_path:
sys.path.insert(0, str(root))
import triton_kernels # noqa: F401
if should_add_to_path:
sys.path.remove(str(root))
_setup_vendored_triton_kernels()
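As a quick sanity check that this setup works, the vendored copy can be identified by the metadata files the vendor script adds next to the package; a minimal sketch, assuming `tensorrt_llm` is installed:

```python
from pathlib import Path

import tensorrt_llm  # noqa: F401 -- importing this runs _setup_vendored_triton_kernels()
import triton_kernels

pkg_dir = Path(triton_kernels.__file__).parent
# The vendor script ships VERSION and LICENSE alongside the package's __init__.py;
# an external triton_kernels installation would not have them.
print(pkg_dir)
print("vendored:", (pkg_dir / "VERSION").exists() and (pkg_dir / "LICENSE").exists())
```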
# Need to import torch before tensorrt_llm library, otherwise some shared binary files
# cannot be found for the public PyTorch, raising errors like:


@ -4,36 +4,15 @@ from typing import Callable, Tuple
import torch
import torch.nn.functional as F
from triton_kernels.matmul_ogs import FlexCtx, FnSpecs, FusedActivation, PrecisionConfig, matmul_ogs
from triton_kernels.numerics import InFlexData
from triton_kernels.routing import RoutingData, routing
from triton_kernels.swiglu import swiglu_fn
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout
IS_TRITON_KERNELS_AVAILABLE = True
TRITON_KERNELS_UNAVAILABLE_REASON = ""
try:
from triton_kernels.matmul_ogs import (
FlexCtx,
FnSpecs,
FusedActivation,
PrecisionConfig,
matmul_ogs,
)
from triton_kernels.numerics import InFlexData
from triton_kernels.routing import RoutingData, routing
from triton_kernels.swiglu import swiglu_fn
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from triton_kernels.tensor_details.layout import StridedLayout
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import TritonEPRouter
except Exception as _e:
IS_TRITON_KERNELS_AVAILABLE = False
TRITON_KERNELS_UNAVAILABLE_REASON = f"{type(_e).__name__}: {_e}"
FlexCtx = FnSpecs = FusedActivation = PrecisionConfig = matmul_ogs = None
InFlexData = RoutingData = routing = swiglu_fn = None
FP4 = convert_layout = wrap_torch_tensor = None
layout = StridedLayout = None
TritonEPRouter = None
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import TritonEPRouter
# copied from transformers.integrations.mxfp4::swizzle_mxfp4 with minor modification


@ -9,7 +9,6 @@ from tensorrt_llm._torch.auto_deploy.utils.pattern_matcher import (
register_ad_pattern,
)
from ...custom_ops.fused_moe.mxfp4_moe import IS_TRITON_KERNELS_AVAILABLE
from ...utils.module import get_submodule_of_param
from ...utils.node_utils import is_op
from ..interface import BaseTransform, TransformInfo, TransformRegistry
@ -220,11 +219,7 @@ class InsertMXFP4MLP(BaseTransform):
shared_config,
) -> Tuple[GraphModule, TransformInfo]:
qcfg = factory.get_quant_config()
if (
not qcfg
or qcfg.get("quant_method", "") != self.algo_name
or not IS_TRITON_KERNELS_AVAILABLE
):
if not qcfg or qcfg.get("quant_method", "") != self.algo_name:
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)


@ -1,32 +1,19 @@
from __future__ import annotations
import os
import sys
from typing import Dict, List, NamedTuple, Optional
import torch
import torch.nn as nn
import triton
import triton.language as tl
IS_TRITON_KERNELS_AVAILABLE = False
# We expect to find triton_kernels under $TRITON_ROOT/python/triton_kernels
# Triton upstream commit f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f has been verified.
triton_root = os.getenv('TRITON_ROOT')
if triton_root:
triton_root = os.path.abspath(
os.path.join(triton_root, 'python', 'triton_kernels'))
if os.path.exists(triton_root) and triton_root not in sys.path:
sys.path.insert(0, triton_root)
assert triton.__version__ >= "3.4.0", "Triton kernels are detected but the Triton wheel is too old"
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import (FlexCtx, FnSpecs, FusedActivation,
PrecisionConfig, matmul_ogs)
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
IS_TRITON_KERNELS_AVAILABLE = True
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import (FlexCtx, FnSpecs, FusedActivation,
PrecisionConfig, matmul_ogs)
from triton_kernels.numerics import InFlexData
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
from triton_kernels.tensor_details import layout
from ...model_config import ModelConfig
from ..linear import TensorParallelMode, load_weight_shard
@ -214,11 +201,16 @@ class TritonUnquantizedFusedMoEMethod(FusedMoEMethodBase):
module.intermediate_size_per_partition,
module.hidden_size,
)
# Bias shapes use the output dimension (last dim) of the transposed weight shapes
w3_w1_bias_shape = (w3_w1_weight_shape[0], w3_w1_weight_shape[2])
w2_bias_shape = (w2_weight_shape[0], w2_weight_shape[2])
super().create_weights(module,
weight_dtype,
w3_w1_weight_shape,
w2_weight_shape,
bias_dtype=torch.float32)
bias_dtype=torch.float32,
w3_w1_bias_shape=w3_w1_bias_shape,
w2_bias_shape=w2_bias_shape)
self.setup_quant_scales(module)
def setup_quant_scales(self, module: torch.nn.Module):
@ -404,12 +396,17 @@ class TritonFP8QDQFusedMoEMethod(TritonUnquantizedFusedMoEMethod):
module.intermediate_size_per_partition,
module.hidden_size,
)
# Bias shapes use the output dimension (last dim) of the transposed weight shapes
w3_w1_bias_shape = (w3_w1_weight_shape[0], w3_w1_weight_shape[2])
w2_bias_shape = (w2_weight_shape[0], w2_weight_shape[2])
FusedMoEMethodBase.create_weights(self,
module,
weight_dtype,
w3_w1_weight_shape,
w2_weight_shape,
bias_dtype=torch.float32)
bias_dtype=torch.float32,
w3_w1_bias_shape=w3_w1_bias_shape,
w2_bias_shape=w2_bias_shape)
fc31_dequant = nn.Parameter(torch.empty(
module.expert_size_per_partition, dtype=torch.float32),
@ -1295,8 +1292,6 @@ class TritonFusedMoE(MoE):
weight_loading_mode=weight_loading_mode,
layer_idx=layer_idx,
)
if not IS_TRITON_KERNELS_AVAILABLE:
raise ImportError("Triton kernels are not available.")
if torch.cuda.get_device_capability()[0] != 9 and self.ep_size > 1:
raise NotImplementedError(
"TritonFusedMoE is only supported on Hopper with EP size > 1.")


@ -623,6 +623,8 @@ def load_activation_scales_fp8_qdq(module: torch.nn.Module, weights: Dict):
load_expert_fc2_input_scale_fp8_qdq(w2_input_scale,
tmp_fc2_input_scale[expert_id])
return tmp_fc31_input_scale.max(), tmp_fc2_input_scale.max()
def requantize_expert_w3_w1_weight_fp8_qdq(module: torch.nn.Module,
w1_weight_scale, w3_weight_scale,


@ -4,20 +4,15 @@ from typing import Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, matmul_ogs
from triton_kernels.numerics import InFlexData
from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm.mapping import Mapping
from ...models.modeling_utils import QuantConfig
# Reuse the common Triton import setup
from .fused_moe.fused_moe_triton import (IS_TRITON_KERNELS_AVAILABLE,
maybe_update_stride,
from .fused_moe.fused_moe_triton import (maybe_update_stride,
swizzle_weight_and_scale)
if IS_TRITON_KERNELS_AVAILABLE:
from triton_kernels.matmul_ogs import (FlexCtx, PrecisionConfig, matmul_ogs)
from triton_kernels.numerics import InFlexData
from .linear import (Linear, LinearMethodBase, TensorParallelMode,
WeightsLoadingConfig, copy_weight, load_weight_shard,
load_weights_fused_gate_up_helper,
@ -383,9 +378,6 @@ class TritonLinear(Linear):
use_custom_cublas_mm: bool = False,
lora: Optional[LoraLayer] = None,
):
if not IS_TRITON_KERNELS_AVAILABLE:
raise ImportError("Triton kernels are not available. "
"Please install the required dependencies.")
assert not use_custom_cublas_mm, "TritonLinear does not support custom cublas mm."
super().__init__(


@ -253,10 +253,6 @@ microsoft/phi-4:
accuracy: 90.64
mistralai/Codestral-22B-v0.1:
- accuracy: 67.10
GPT-OSS/BF16:
- accuracy: 90.3
- kv_cache_quant_algo: FP8
accuracy: 90.3
GPT-OSS/120B-MXFP4:
- accuracy: 90.3
- spec_dec_algo: Eagle


@ -51,8 +51,6 @@ from defs.conftest import get_sm_version, is_sm_100f
from tensorrt_llm import LLM
from tensorrt_llm._torch.model_config import MoeLoadBalancerConfig
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
IS_TRITON_KERNELS_AVAILABLE
from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
DeepSeekSparseAttentionConfig,
Eagle3DecodingConfig, KvCacheConfig, MoeConfig,
@ -3977,7 +3975,10 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
@pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRITON", "TRTLLM"])
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS",
pytest.param("TRITON", marks=skip_no_hopper), "TRTLLM"])
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
(1, 1, 1, False, True, True),
@ -4008,11 +4009,6 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
patch_mpi_pool_session_for_env(mocker,
{"ENABLE_CONFIGURABLE_MOE": env_value})
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("TRITON moe backend is not available.")
if get_sm_version() < 90:
pytest.skip("TRITON moe backend requires Hopper or newer.")
if moe_backend in ["CUTLASS", "TRTLLM"] and get_sm_version() < 100:
pytest.skip(
"CUTLASS or TRTLLM moe backend requires Blackwell or newer.")
@ -4467,11 +4463,12 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
@pytest.mark.parametrize(
"kv_cache_dtype",
["auto", pytest.param("fp8", marks=skip_pre_blackwell)])
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize("moe_backend", [
"CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell),
pytest.param("TRITON", marks=skip_no_hopper)
],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize("cuda_graph,overlap_scheduler", [
(True, True),
])
@ -4480,8 +4477,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
{"scores_filter": "exact_match,flexible-extract"})
if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
@ -4518,11 +4513,12 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
@pytest.mark.parametrize(
"kv_cache_dtype",
["auto", pytest.param("fp8", marks=skip_pre_blackwell)])
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize("moe_backend", [
"CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell),
pytest.param("TRITON", marks=skip_no_hopper)
],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
(4, 1, 1, False, True, True),
@ -4551,10 +4547,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
patch_mpi_pool_session_for_env(mocker,
{"ENABLE_CONFIGURABLE_MOE": env_value})
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
MAX_OUTPUT_LEN = 128179
MAX_INPUT_LEN = 32768
@ -4612,11 +4604,12 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
extra_evaluator_kwargs=extra_evaluator_kwargs)
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize("moe_backend", [
"CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell),
pytest.param("TRITON", marks=skip_no_hopper)
],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
(8, 1, 1, False, True, True),
@ -4629,9 +4622,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
{"scores_filter": "exact_match,flexible-extract"})
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
@ -4653,6 +4643,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
extra_evaluator_kwargs=self.extra_evaluator_kwargs)
@pytest.mark.skip_less_device(4)
@skip_no_hopper
@pytest.mark.parametrize(
"kv_cache_dtype",
["auto", pytest.param("fp8", marks=skip_pre_blackwell)])
@ -4667,14 +4658,8 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
{"scores_filter": "exact_match,flexible-extract"})
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
monkeypatch.setenv("OVERRIDE_QUANT_ALGO", "W4A16_MXFP4")
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None)
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
dtype=kv_cache_dtype)
@ -4683,11 +4668,12 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
pipeline_parallel_size=pp_size,
moe_expert_parallel_size=ep_size,
kv_cache_config=kv_cache_config,
**pytorch_config,
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
enable_attention_dp=attention_dp,
moe_config=MoeConfig(backend="TRITON"))
with llm:
model_name = "GPT-OSS/BF16"
model_name = "GPT-OSS/120B-MXFP4"
task = GSM8K(model_name)
task.evaluate(llm,
extra_evaluator_kwargs=self.extra_evaluator_kwargs)
@ -4696,11 +4682,12 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
@pytest.mark.parametrize(
"kv_cache_dtype",
["auto", pytest.param("fp8", marks=skip_pre_blackwell)])
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize("moe_backend", [
"CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell),
pytest.param("TRITON", marks=skip_no_hopper)
],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize(
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
(2, 1, 1, False, True, True),
@ -4711,10 +4698,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
def test_w4_2gpus(self, kv_cache_dtype, moe_backend, tp_size, pp_size,
ep_size, attention_dp, cuda_graph, overlap_scheduler,
mocker):
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None)
@ -4783,16 +4766,13 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
@pytest.mark.parametrize(
"kv_cache_dtype",
["auto", pytest.param("fp8", marks=skip_pre_blackwell)])
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize("moe_backend", [
"CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell),
pytest.param("TRITON", marks=skip_no_hopper)
],
ids=["cutlass", "trtllm", "triton"])
def test_w4_chunked_prefill(self, kv_cache_dtype, moe_backend, mocker):
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
MAX_OUTPUT_LEN = 128179
MAX_INPUT_LEN = 32768
@ -4852,17 +4832,14 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
ids=["overlap_scheduler", "no_overlap_scheduler"])
@pytest.mark.parametrize("one_model", [True, False],
ids=["one_model", "two_model"])
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize("moe_backend", [
"CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell),
pytest.param("TRITON", marks=skip_no_hopper)
],
ids=["cutlass", "trtllm", "triton"])
def test_eagle3_4gpus(self, moe_backend, one_model, overlap_scheduler,
mocker):
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
if get_sm_version() == 90:
pytest.skip(
"https://nvbugs/5636916: Remaining Hopper Eagle Accuracy Issue for only TP=4"
@ -5044,17 +5021,14 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
ids=["overlap_scheduler", "no_overlap_scheduler"])
@pytest.mark.parametrize("one_model", [True, False],
ids=["one_model", "two_model"])
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
ids=["cutlass", "trtllm", "triton"])
@pytest.mark.parametrize("moe_backend", [
"CUTLASS",
pytest.param("TRTLLM", marks=skip_pre_blackwell),
pytest.param("TRITON", marks=skip_no_hopper)
],
ids=["cutlass", "trtllm", "triton"])
def test_eagle3_2gpus(self, moe_backend, one_model, overlap_scheduler,
mocker):
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
MAX_OUTPUT_LEN = 128179
MAX_INPUT_LEN = 32768


@ -4,11 +4,8 @@ import pytest
import torch
import torch.distributed as dist
from _dist_test_utils import get_device_counts
from utils.util import getSMVersion
from utils.util import skip_no_hopper
from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.mxfp4_moe import (
IS_TRITON_KERNELS_AVAILABLE,
)
from tensorrt_llm._torch.auto_deploy.distributed.common import spawn_multiprocess_job
@ -110,14 +107,7 @@ def _run_mxfp4_mlp_ep_dtype_test(num_experts: int, topk: int, rank: int, world_s
torch.testing.assert_close(part_out, ref_out, rtol=5e-2, atol=5e-2, equal_nan=True)
@pytest.mark.skipif(
getSMVersion() != 90,
reason="triton_mxfp4_moe is only supported in Hopper architecture",
)
@pytest.mark.skipif(
not IS_TRITON_KERNELS_AVAILABLE,
reason="triton_kernels unavailable",
)
@skip_no_hopper
@pytest.mark.parametrize("num_experts", [6, 8])
@pytest.mark.parametrize("topk", [4]) # must be <= num_experts
@pytest.mark.parametrize("device_count", get_device_counts())


@ -5,10 +5,9 @@ import shutil
import pytest
from transformers import AutoTokenizer
from utils.llm_data import llm_models_root
from utils.util import skip_no_hopper
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
IS_TRITON_KERNELS_AVAILABLE
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MoeConfig
configs = """
@ -48,11 +47,10 @@ def dump_config_json(dst_dir):
json.dump(json_configs, f, indent=2, ensure_ascii=False)
@pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRITON"])
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS", pytest.param("TRITON", marks=skip_no_hopper)])
def test_gpt_oss_trtllmgen(moe_backend):
if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
prompts = [
"How are you?",
"Hello, my name is",


@ -19,9 +19,8 @@ from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from transformers.configuration_utils import PretrainedConfig
from utils.util import (check_accuracy, skip_blackwell, skip_blackwell_geforce,
skip_neither_ada_nor_hopper_unittest,
skip_non_hopper_unittest, skip_pre_blackwell,
skip_pre_hopper)
skip_neither_ada_nor_hopper_unittest, skip_no_hopper,
skip_pre_blackwell, skip_pre_hopper)
from tensorrt_llm._torch.autotuner import AutoTuner, autotune
from tensorrt_llm._torch.model_config import ModelConfig
@ -41,8 +40,6 @@ from tensorrt_llm._torch.modules.fused_moe import (
from tensorrt_llm._torch.modules.fused_moe.quantization import \
NVFP4CutlassFusedMoEMethod
# isort: on
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
IS_TRITON_KERNELS_AVAILABLE
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._utils import get_sm_version, mpi_rank
from tensorrt_llm.mapping import Mapping
@ -92,8 +89,8 @@ def test_fused_moe(moe_backend,
mapping=None):
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
if get_sm_version() != 90:
pytest.skip("TRITON moe backend is only supported on Hopper")
if dtype != torch.bfloat16:
pytest.skip("Unsupported for TritonFusedMoE")
if routing_cls != RenormalizeMoeRoutingMethod:
@ -192,9 +189,9 @@ def test_fused_moe(moe_backend,
# Evaluate outputs
torch.cuda.synchronize()
# There can be one off mismatch in the outputs due to different kernel implementations
# Here we check 99% of the outputs are within the tolerance
# Here we check most of the outputs are within the tolerance
# The CutlassFusedMoE case fails as well without this change on H100 for bf16
check_accuracy(output, ref_output, rtol=0.2, atol=0.2, percent=0.984)
check_accuracy(output, ref_output, rtol=0.2, atol=0.2, percent=0.975)
m //= 2
@ -514,7 +511,9 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
@skip_pre_hopper
@pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRITON"])
@pytest.mark.parametrize(
"moe_backend",
["CUTLASS", pytest.param("TRITON", marks=skip_no_hopper)])
@pytest.mark.parametrize("routing_cls",
[DefaultMoeRoutingMethod, RenormalizeMoeRoutingMethod])
@pytest.mark.parametrize("bias", [True, False])
@ -522,8 +521,6 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
def test_fused_moe_fp8(moe_backend, dtype, routing_cls, bias):
if moe_backend == "TRITON":
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
if dtype != torch.bfloat16:
pytest.skip("Unsupported for TritonFusedMoE")
if routing_cls != RenormalizeMoeRoutingMethod:
@ -632,19 +629,30 @@ def test_fused_moe_fp8(moe_backend, dtype, routing_cls, bias):
with torch.inference_mode(), autotune():
fused_moe.forward(x, router_logits)
# Explicitly capture context for kernel testing
with AutoTuner.get().capture() as all_tactics, torch.inference_mode():
output = fused_moe.forward(x, router_logits)
# Test all kernel tactics
for tactic in all_tactics:
with AutoTuner.get().replay(tactic), torch.inference_mode():
# TRITON backend uses Triton kernels which don't register with AutoTuner
if moe_backend == "TRITON":
with torch.inference_mode():
output = fused_moe.forward(x, router_logits)
check_accuracy(output,
ref_output,
rtol=0.04,
atol=0.1,
percent=0.99)
check_accuracy(output,
ref_output,
rtol=0.04,
atol=0.1,
percent=0.99)
else:
# Explicitly capture context for kernel testing
with AutoTuner.get().capture() as all_tactics, torch.inference_mode(
):
output = fused_moe.forward(x, router_logits)
# Test all kernel tactics
for tactic in all_tactics:
with AutoTuner.get().replay(tactic), torch.inference_mode():
output = fused_moe.forward(x, router_logits)
check_accuracy(output,
ref_output,
rtol=0.04,
atol=0.1,
percent=0.99)
def set_tensor_value_2(x, num_row, num_cols):
@ -1174,7 +1182,7 @@ def test_fused_moe_fp8_blockwise_cute_dsl(dtype,
return True
@skip_non_hopper_unittest
@skip_no_hopper
@pytest.mark.parametrize(
"dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode",
product(
@ -1306,7 +1314,7 @@ def test_fused_moe_fp8_blockwise_cutlass(dtype,
return True
@skip_non_hopper_unittest
@skip_no_hopper
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
@pytest.mark.parametrize("ep_size", [1, 2, 4])
@ -2526,7 +2534,7 @@ def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend,
check_accuracy(output, ref_output, rtol=1e-2, atol=0.1, percent=0.99)
@skip_pre_hopper
@skip_no_hopper
@pytest.mark.parametrize("experts", [8, 128])
@pytest.mark.parametrize(
"hidden_size, intermediate_size",
@ -2542,12 +2550,8 @@ def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend,
@pytest.mark.parametrize("dynamic_quant", [True, False])
def test_fused_moe_triton_mxfp4(experts, hidden_size, intermediate_size,
fp8_activation, bias, dynamic_quant):
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
if torch.cuda.get_device_capability()[0] < 10 and fp8_activation:
if fp8_activation:
pytest.skip("Latest Triton requires BF16 activation on Hopper")
if torch.cuda.get_device_capability()[0] >= 10 and not fp8_activation:
pytest.skip("Latest Triton requires FP8 activation on Blackwell")
mapping = Mapping()
mapping.rank = mpi_rank()


@ -5,10 +5,8 @@ import cloudpickle
import pytest
import torch
from mpi4py import MPI
from utils.util import check_accuracy, skip_pre_hopper
from utils.util import check_accuracy, skip_no_hopper
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
IS_TRITON_KERNELS_AVAILABLE
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.triton_linear import TritonLinear
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
@ -21,11 +19,10 @@ MPI.pickle.__init__(
)
@pytest.mark.parametrize("linear_cls", [Linear, TritonLinear])
@pytest.mark.parametrize(
"linear_cls",
[Linear, pytest.param(TritonLinear, marks=skip_no_hopper)])
def test_linear_unquantized(linear_cls):
if not IS_TRITON_KERNELS_AVAILABLE and linear_cls is TritonLinear:
pytest.skip("Triton kernels are not available")
torch.manual_seed(0)
torch.cuda.manual_seed(0)
num_tokens = 128
@ -56,11 +53,10 @@ def test_linear_unquantized(linear_cls):
check_accuracy(actual_c, reference_c, atol=0.01, rtol=0.01, percent=0.99)
@pytest.mark.parametrize("linear_cls", [Linear, TritonLinear])
@pytest.mark.parametrize(
"linear_cls",
[Linear, pytest.param(TritonLinear, marks=skip_no_hopper)])
def test_linear_fp8qdq(linear_cls):
if not IS_TRITON_KERNELS_AVAILABLE and linear_cls is TritonLinear:
pytest.skip("Triton kernels are not available")
torch.manual_seed(0)
torch.cuda.manual_seed(0)
num_tokens = 128
@ -100,18 +96,12 @@ def test_linear_fp8qdq(linear_cls):
percent=0.99)
@skip_pre_hopper
@skip_no_hopper
@pytest.mark.parametrize("activation_dtype",
[torch.bfloat16, torch.float8_e4m3fn])
def test_linear_mxfp4(activation_dtype):
if not IS_TRITON_KERNELS_AVAILABLE:
pytest.skip("Triton kernels are not available")
if torch.cuda.get_device_capability(
)[0] < 10 and activation_dtype == torch.float8_e4m3fn:
if activation_dtype == torch.float8_e4m3fn:
pytest.skip("Latest Triton requires BF16 activation on Hopper")
if torch.cuda.get_device_capability(
)[0] >= 10 and activation_dtype == torch.bfloat16:
pytest.skip("Latest Triton requires FP8 activation on Blackwell")
dtype = torch.bfloat16
num_tokens = 128


@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for vendored triton_kernels package.
The triton_kernels package is vendored from the Triton project to provide
optimized kernels. These tests verify that the vendoring mechanism works
correctly and that our version takes precedence over any external installation.
"""
import unittest
from pathlib import Path
class TestTritonKernelsVendoring(unittest.TestCase):
def test_triton_kernels_is_vendored_version(self):
"""Verify we're using the vendored version (has VERSION and LICENSE files)."""
import tensorrt_llm # noqa: F401, I001
import triton_kernels
triton_kernels_path = Path(triton_kernels.__file__).parent
# VERSION file is added by our vendor script and not present in external installations
version_file = triton_kernels_path / "VERSION"
self.assertTrue(
version_file.exists(),
f"VERSION file not found at {version_file}. "
"This suggests an external triton_kernels is being used instead of our vendored version.",
)
# LICENSE file should also be present for compliance
license_file = triton_kernels_path / "LICENSE"
self.assertTrue(license_file.exists(), f"LICENSE file not found at {license_file}.")
def test_version_matches_requirements(self):
"""Verify vendored triton_kernels VERSION matches triton version in requirements.txt."""
import re
repo_root = Path(__file__).parent.parent.parent.parent
version_file = repo_root / "triton_kernels" / "VERSION"
vendored_version = version_file.read_text().strip().split()[0].lstrip("v")
requirements_file = repo_root / "requirements.txt"
requirements_text = requirements_file.read_text()
match = re.search(r"^triton==([^\s#]+)", requirements_text, re.MULTILINE)
self.assertIsNotNone(match, "Could not find triton version in requirements.txt")
requirements_version = match.group(1)
self.assertEqual(
vendored_version,
requirements_version,
f"Vendored triton_kernels version ({vendored_version}) does not match "
f"triton version in requirements.txt ({requirements_version}). "
"To update the vendored triton_kernels, run: python scripts/vendor_triton_kernels.py "
f"--tag v{requirements_version}",
)
if __name__ == "__main__":
unittest.main()


@ -106,6 +106,9 @@ skip_pre_ada = pytest.mark.skipif(
skip_pre_hopper = pytest.mark.skipif(
getSMVersion() < 90,
reason="This test is not supported in pre-Hopper architecture")
skip_no_hopper = pytest.mark.skipif(
getSMVersion() != 90,
reason="This test is only supported in Hopper architecture")
skip_pre_blackwell = pytest.mark.skipif(
getSMVersion() < 100,
reason="This test is not supported in pre-Blackwell architecture")

triton_kernels/LICENSE Normal file

@ -0,0 +1,23 @@
/*
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

triton_kernels/README.md Normal file

@ -0,0 +1,24 @@
# Vendored triton_kernels
This directory contains code vendored from the [Triton](https://github.com/triton-lang/triton) project.
| | |
|---|---|
| **Copyright** | The Triton Authors |
| **License** | MIT (see [LICENSE](LICENSE) file in this directory) |
| **Source** | https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels |
| **Version** | `v3.5.1` |
## Attribution
This code is the work of the Triton authors and is included here under the MIT License.
Each Python file includes an attribution header indicating its origin.
## Do Not Edit
This code is vendored verbatim and should not be modified directly.
To update to a newer version, run:
```bash
python scripts/vendor_triton_kernels.py --tag <new-tag>
```

triton_kernels/VERSION Normal file

@ -0,0 +1,3 @@
v3.5.1
# This file tracks the version of triton-kernels that was vendored.
# To update, run: python scripts/vendor_triton_kernels.py --tag <new-tag>


@ -0,0 +1,3 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/__init__.py
# Triton is licensed under the MIT License.


@ -0,0 +1,73 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/compaction.py
# Triton is licensed under the MIT License.
import torch
from .compaction_details._masked_compaction import _masked_compaction
from .tensor import Bitmatrix
def compaction(yv, yi, bitmask, sentinel=-1):
"""
Return compacted copies of *yv* and *yi* based on a per-row bitmask.
Only the elements whose index appears among the active bits of *bitmask*
are kept; the rest are replaced by *sentinel*. Kept elements preserve
their original left-to-right order.
Parameters
----------
yv : torch.Tensor, shape (B, K)
Values tensor.
yi : torch.Tensor, shape (B, K), dtype torch.long
Integer indices (0 ≤ index < 32) associated with *yv*.
bitmask : torch.Tensor, shape (B,) **or** (B, 32)
Per-row mask of active indices. See the in-place version for details.
sentinel : int, default -1
Value written into dropped positions of the returned tensors.
Returns
-------
(yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
New tensors with the same dtype/device as the inputs.
"""
n_rows, n_cols = yi.shape
ret_yv = torch.empty_like(yv)
ret_yi = torch.empty_like(yi)
if isinstance(bitmask, Bitmatrix):
bitmask = bitmask.storage.data
_masked_compaction[(n_rows, )](
yv, yi, bitmask, bitmask.stride(0), bitmask.stride(1), # inputs
ret_yv, ret_yi, # outputs
sentinel, # sentinel
K=n_cols # constants
)
return ret_yv, ret_yi
def compaction_torch(yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1):
"""
reference implementation of `masked_compact`
"""
B, K = yi.shape
device = yi.device
# Expand bitmask to a boolean matrix of active bits (B, 32)
w = (1 << torch.arange(32, device=device, dtype=bitmask.dtype))
bits = (bitmask.unsqueeze(-1) & w) != 0
mask = bits.flatten(start_dim=-2) # or bits.reshape(B, -1)
# For every yi element decide whether it should be kept
keep = mask.gather(1, yi.long())
# Build a stable permutation that brings all "keep" items forward
# False→0, True→1 ==> invert so kept==0, dropped==1, then argsort
order = (~keep).to(torch.int).argsort(dim=1, stable=True)
# Reorder tensors according to above permutation
yi_sorted = yi.gather(1, order)
yv_sorted = yv.gather(1, order)
# fill relevant positions with sentinel
keep_sorted = keep.gather(1, order)
yi_sorted[~keep_sorted] = sentinel
yv_sorted[~keep_sorted] = sentinel
return yv_sorted, yi_sorted
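For orientation, a small CPU-only exercise of the reference implementation above; a sketch, assuming the vendored package is importable (values and masks chosen purely for illustration):

```python
import torch

from triton_kernels.compaction import compaction_torch

B, K = 2, 4
yv = torch.arange(B * K, dtype=torch.float32).reshape(B, K)
yi = torch.tensor([[0, 3, 7, 12], [1, 2, 5, 31]])
# One 32-bit word per row (shape (B, 1)); row 0 keeps indices {0, 3, 7},
# row 1 keeps {1, 5, 31}, and every other position is replaced by the sentinel -1.
bitmask = torch.tensor([[(1 << 0) | (1 << 3) | (1 << 7)],
                        [(1 << 1) | (1 << 5) | (1 << 31)]], dtype=torch.int64)
yv_out, yi_out = compaction_torch(yv, yi, bitmask)
print(yi_out)  # rows: [0, 3, 7, -1] and [1, 5, 31, -1]
```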


@ -0,0 +1 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.


@ -0,0 +1,24 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/compaction_details/_masked_compaction.py
# Triton is licensed under the MIT License.
import triton
import triton.language as tl
@triton.jit
def _masked_compaction(Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr):
pid_m = tl.program_id(0)
yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
div = yi // 32
rem = yi % 32
active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
active_flags = active_bits.to(tl.int1)
rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K))
write_indx = exc_cumsum + rev_arange
yv = tl.where(active_flags, yv, sentinel)
yi = tl.where(active_flags, yi, sentinel)
tl.store(RetYv + pid_m * K + write_indx, yv)
tl.store(RetYi + pid_m * K + write_indx, yi)


@ -0,0 +1,613 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/matmul_ogs.py
# Triton is licensed under the MIT License.
# isort: off
# fmt: off
from dataclasses import dataclass
import itertools
import sys
import torch
import triton
from enum import Enum, auto
import math
# utilities
from triton_kernels import target_info
from triton_kernels.numerics import InFlexData, OutFlexData
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from triton_kernels.target_info import is_cuda
# details
from .matmul_ogs_details._matmul_ogs import _matmul_ogs
from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
from .matmul_ogs_details._reduce_grouped import _reduce_grouped
from .numerics_details.mxfp import MXFP_BLOCK_SIZE
from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints, InapplicableConstraint
from .specialize import specialize
from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
@dataclass(frozen=True)
class FnSpecs:
name: str
fn: "triton.runtime.jit.JITFunction"
fn_arg_names: tuple[str]
fn_arg_do_not_specialize: tuple[str] = tuple()
@staticmethod
def default():
return FnSpecs("dflt", None, tuple())
@dataclass(frozen=True)
class FusedActivation:
specs: FnSpecs = FnSpecs.default()
fn_args: tuple[object] = tuple()
reduction_n: int = 1
@dataclass(frozen=True)
class Epilogue:
specs: FnSpecs = FnSpecs.default()
fn_arg_values_matmul: tuple[object] = tuple()
fn_arg_values_finalize: tuple[object] = tuple()
effective_itemsize: float = None
class FnName(Enum):
QUANTIZE_MXFP8 = auto()
EpilogueSpecs = FnSpecs # TODO: remove this alias when callers are updated
_kernels = dict()
def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()):
global _kernels
key = (fused_activation.name, epilogue.name)
if key in _kernels:
return _kernels[key]
spec_constants = {
"ACTIVATION_FN": fused_activation.fn,
"EPILOGUE_FN": epilogue.fn,
}
spec_tuples = {
"activation_fn_args": fused_activation.fn_arg_names,
"epilogue_fn_args": epilogue.fn_arg_names,
}
do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize
import types
module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}")
sys.modules[module.__name__] = module
module._matmul_ogs = specialize(_matmul_ogs, module, spec_constants, spec_tuples,
do_not_specialize=do_not_specialize)
module._p_matmul_ogs = specialize(_p_matmul_ogs, module, spec_constants, spec_tuples,
do_not_specialize=do_not_specialize)
module._reduce_grouped = specialize(_reduce_grouped, module, spec_constants, spec_tuples,
do_not_specialize=do_not_specialize)
_kernels[key] = module
return module
# -----------------------------------------------------------------------------
# Matrix Multiplication + Outer Gather/Scatter
# -----------------------------------------------------------------------------
def can_overflow_int32(tensor: torch.Tensor):
max_int32 = (1 << 31) - 1
offset = 0
for i in range(tensor.ndim):
offset += (tensor.shape[i] - 1) * tensor.stride(i)
return offset > max_int32
def should_upcast_indices(*args):
return any(tensor is not None and can_overflow_int32(tensor) for tensor in args)
# ---------------------
# Numerics
# ---------------------
# fmt: off
@dataclass(frozen=True)
class FlexCtx:
lhs_data: InFlexData = InFlexData()
rhs_data: InFlexData = InFlexData()
out_data: OutFlexData = OutFlexData()
@dataclass
class PrecisionConfig:
max_num_imprecise_acc: int = None
allow_tf32: bool = True
flex_ctx: FlexCtx = FlexCtx()
acc_scale: int = 1.0
flexpoint_saturate_inf: bool = False
report_quantization_err_fn: callable = None
act_scale: Tensor | None = None
weight_scale: Tensor| None = None
out_scale: Tensor | None = None
out_dtype: torch.dtype = None
enforce_bitwise_invariance: bool = False
# TODO: merge in opt_flags
def get_swap_xw(precision_config, opt_flags):
if target_info.cuda_capability_geq(10, 0):
return precision_config.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent
return False
# ---------------------
# Allocation
# ---------------------
@dataclass
class MatmulAllocation:
device: str
output: tuple[tuple[int], torch.dtype]
scratchpads: dict[str, tuple]
def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags):
# ---- output ------
N = w.shape[-1]
# by default - M is number of rows in the activations
M = x.shape[-2]
# if the activations are gathered, then M is number of gather indices
if gather_indx is not None:
M = gather_indx.src_indx.shape[0]
# final output
if routing_data.n_expts_act == 1 or scatter_indx is None:
y_rows = M
else:
Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows
y_rows = Mc
batch_dim = x.shape[0] if x.ndim == 3 else 1
out_shape = (batch_dim, y_rows, N // fused_activation.reduction_n)
out_dtype = precision_config.out_dtype or x.dtype
output = (out_shape, out_dtype)
# ---- scratchpad -----#
scratchpad = dict()
if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype
scratchpad["matmul"] = ((opt_flags.split_k, 1, M, N), scratch_out_dtype)
if "matmul" in scratchpad and precision_config.out_scale is not None:
scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE)), torch.uint8)
return MatmulAllocation(x.device, output, scratchpad)
def apply_allocation(allocation: MatmulAllocation, output):
ret = dict()
if output is None:
output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
else:
assert output.shape == allocation.output[0]
ret["output"] = output[None, :, :]
ret["scratchpad"] = {
k: torch.empty(v[0], device=allocation.device, dtype=v[1])
for k, v in allocation.scratchpads.items()
}
return ret
# -----------------------------------------------------------------------------
# Canonicalize
# -----------------------------------------------------------------------------
# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used
# we can canonicalize storages to make the implementation more uniform
def _canonicalize_storage(storage, out_ndim, flex_data):
assert out_ndim >= storage.data.ndim
# Need to use as_strided instead of view because for a tensor with
# shape[-2] == 1 can have ambiguity related to col-wise. For example,
# > t = torch.randn(2, 5, 1).mT
# > t_view = t.view(t.shape)
# > t.stride(), t_view.stride()
# ((5, 1, 1), (5, 5, 1))
# Our check t_view is col-wise fails since t_view.stride(-2) != 1
# This case is covered by (m, n, k) == (1000, 700, 2) in test_matmul.py
new_storage_shape = [1] * (out_ndim - storage.data.ndim) + list(storage.data.shape)
new_storage_view = storage.data.view(new_storage_shape)
new_storage_stride = [new_storage_view.stride(0)] * (out_ndim - storage.data.ndim) + list(storage.data.stride())
new_storage_data = storage.data.as_strided(new_storage_shape, new_storage_stride)
if flex_data is not None:
new_storage_data = flex_data.reinterpret(new_storage_data)
return Storage(new_storage_data, storage.layout)
#
def reduce_grouped(x: torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_mx_scale: torch.Tensor,
fused_activation, epilogue,
x_flex: InFlexData | None = None,
out_flex: OutFlexData | None = None, x_mx_scale: torch.Tensor | None = None,
out_dtype: bool = None, flexpoint_saturate_inf: bool = False):
"""
In-place grouped row reduction.
Arguments
- x: Tensor[AnyFloat] of shape [(num_groups * K), N]
- indx: Tensor[Int] of shape [num_groups, K]
Description
For each group g in [0, num_groups), this routine sums the K rows of `x`
specified by `indx[g, :]` and overwrites the row corresponding to the first
valid (non-negative) index with the per-group sum. Accumulation is performed
in float32 for numerical stability, and the result is written back in the
dtype of `x`.
Behavior and edge cases
- Invalid (-1) entries are skipped during accumulation and do not generate
memory traffic. If a group has no valid entries, nothing is written for
that group.
- Reduction is performed tile-by-tile along the N dimension within a single
kernel launch (persistent along N) to minimize launch overhead.
Performance notes
- Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x),
plus index reads. With no invalid entries, this becomes (K + 1) reads/writes
of length N per group.
Returns
- The input tensor `x` (modified in place).
"""
if indx is None and x.shape[0] == 1:
return x.squeeze(0), None
if indx is not None:
num_groups = indx.shape[0]
else:
num_groups = x.shape[-2]
if x_flex is None:
x_flex = InFlexData()
if out_flex is None:
out_flex = OutFlexData()
K = 1 if indx is None else indx.shape[1]
out_dtype = x.dtype if out_dtype is None else out_dtype
assert x.shape[-1] % fused_activation.reduction_n == 0
BLOCK_N = 512
# Resolve scalar flex scales (may be None)
x_expected_scale = None if x_flex is None else x_flex.scale
out_expected_scale = None if out_flex is None else out_flex.expected_scale
out_actual_scale = None if out_flex is None else out_flex.actual_scale
out_checksum_scale = None if out_flex is None else out_flex.checksum_scale
# Resolve MXFP output scale row stride
stride_mxb = 0 if x_mx_scale is None else x_mx_scale.stride(0)
stride_mxs = 0 if x_mx_scale is None else x_mx_scale.stride(1)
stride_omxs = 0 if out_mx_scale is None else out_mx_scale.stride(0)
kernels = get_kernels(epilogue.specs, fused_activation.specs)
kernels._reduce_grouped[(num_groups, )](
x_flex.reinterpret(x), x.stride(0), x.stride(2), x.stride(3), #
x_expected_scale, # scalar input scale
out_flex.reinterpret(out), out.stride(1), out.stride(2), #
out_expected_scale, out_actual_scale, out_checksum_scale, indx, #
x.shape[0], x.shape[-1], #
x_mx_scale, stride_mxb, stride_mxs, #
out_mx_scale, stride_omxs, #
*fused_activation.fn_args, fused_activation.reduction_n,
*epilogue.fn_arg_values_finalize,
HAS_IN_MX_SCALE=x_mx_scale is not None, HAS_OUT_MX_SCALE=out_mx_scale is not None,
FLEXPOINT_SATURATE_INF=flexpoint_saturate_inf, #
BLOCK_N=BLOCK_N, K=K, #
num_warps=1, #
)
return out, out_mx_scale
# -----------------------------------------------------------------------------
# Triton Implementation
# -----------------------------------------------------------------------------
def matmul_ogs_set_idle_sms(num_idle_sms):
"""
persistent kernels will leave `num_idle_sms` idle
"""
update_opt_flags_constraints({"idle_sms": num_idle_sms})
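# Example (illustrative): reserve a couple of SMs for concurrent work before
# launching persistent matmuls; the value 2 below is arbitrary.
#   matmul_ogs_set_idle_sms(2)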
def matmul_ogs(x, w, bias,
routing_data: RoutingData | None = None,
gather_indx: GatherIndx | None = None,
scatter_indx: ScatterIndx | None = None,
precision_config: PrecisionConfig | None = None,
betas: torch.Tensor | None = None,
gammas: torch.Tensor | None = None,
out_alpha: float | None = None,
y: torch.Tensor | None = None,
fused_activation: FusedActivation | None = None,
epilogue: Epilogue | None = None,
):
"""
Y[:, :] = 0.
for e in num_experts:
Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])
"""
is_input_batched = x.ndim == 3
if is_input_batched:
assert gather_indx is None, "gather not supported in batched mode"
assert scatter_indx is None, "scatter not supported in batched mode"
assert routing_data is None, "routing not supported in batched mode"
assert w.ndim == 3 and w.shape[0] == x.shape[0]
# canonicalize inputs
if precision_config is None:
precision_config = PrecisionConfig()
if fused_activation is None:
fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
if epilogue is None:
epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
if routing_data is None:
routing_data = RoutingData(None, None, max(1, w.shape[0]), 1)
# unpack scales
w_scale = precision_config.weight_scale
w_has_mx = w_scale is not None
is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
if not isinstance(w, Tensor):
# TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
dtype = FP4 if w.dtype == torch.uint8 else w.dtype
w = wrap_torch_tensor(w, dtype=dtype)
if w_scale is not None and not isinstance(w_scale, Tensor):
w_scale = Tensor(w_scale)
if w_scale is not None:
w_scale.storage.data = w_scale.data.view(torch.uint8)
w_scale.dtype = torch.uint8
x_scale = precision_config.act_scale
x_has_mx = x_scale is not None
if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp"
if x_scale is not None and not isinstance(x_scale, Tensor):
x_scale = Tensor(x_scale)
if not isinstance(x, Tensor):
x = Tensor(x, dtype=x.dtype)
# determine shapes
has_gather = gather_indx is not None
has_scatter = scatter_indx is not None
is_ragged = routing_data.expt_hist is not None
M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
K, N = w.shape[-2:]
assert K == x.shape[-1]
if x.ndim == 3 and w.ndim == 3:
assert x.shape[0] == w.shape[0]
# compute optimization flags
out_dtype = precision_config.out_dtype or x.dtype
can_use_tma = x.numel() > 0 and x.storage.is_tma_compliant() and \
w.numel() > 0 and w.storage.is_tma_compliant() and \
(w_scale is None or w_scale.storage.is_tma_compliant())
# hopper w/ mxfp4 doesn't support TMA
can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
)
if not can_use_fused_scatter and opt_flags.fused_scatter:
raise InapplicableConstraint("Fused scatter is not supported")
if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp():
raise NotImplementedError("Must use non-persistent kernel for simulated MXFP")
if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp():
raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP")
# fused activation
matmul_fused_activation = fused_activation
    reduce_fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
matmul_fused_activation, reduce_fused_activation = reduce_fused_activation, matmul_fused_activation
# allocate output/scratchpad memory
allocation = init_allocation(x, w, precision_config, fused_activation,
routing_data, gather_indx, scatter_indx, opt_flags)
memory = apply_allocation(allocation, y)
# early exit
if batch_size * M * N == 0:
ret = memory["output"].squeeze(0)
if not is_input_batched:
ret = ret.squeeze(0)
return ret
# TMA descriptors require a global memory allocation
if opt_flags.is_persistent:
triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
# Intermediate tensors and postprocess kernels for each situation
has_scratchpad = "matmul" in memory["scratchpad"]
# Canonical output tensor (matmul scratchpad if present, otherwise final output tensor)
out_matmul = memory["scratchpad"].get("matmul", memory["output"])
out_matmul_flex = OutFlexData() if out_matmul.dtype == torch.float32 else precision_config.flex_ctx.out_data
# Unified mx-scale pointer; when scratchpad exists, prefer its mx buffer
out_matmul_scale = precision_config.out_scale
if out_matmul_scale is not None:
out_matmul_scale = out_matmul_scale.data.view(torch.uint8)
if has_scratchpad and "mx_out_scale" in memory["scratchpad"]:
out_matmul_scale = memory["scratchpad"]["mx_out_scale"]
out_matmul_has_mx = out_matmul_scale is not None and out_matmul.element_size() == 1
# matrix multiplication
flex = precision_config.flex_ctx
bias_stride = None if bias is None else bias.stride(0)
num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
# moe metadata
expt_data = routing_data.expt_data
block_m = opt_flags.block_m
expt_hist = None if expt_data is None else expt_data.hist
expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
# spmd grid
grid_m = triton.cdiv(M, opt_flags.block_m)
if expt_block_pid_map is not None:
grid_m = routing_data.n_blocks(M, opt_flags.block_m)
grid_n = triton.cdiv(N, opt_flags.block_n)
max_grid = batch_size * grid_m * grid_n * opt_flags.split_k
grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
# canonicalize storage
has_gather_tma = has_gather and target_info.has_tma_gather()
has_scatter_tma = opt_flags.fused_scatter and target_info.has_tma_gather()
y = wrap_torch_tensor(out_matmul.view(math.prod(out_matmul.shape[:-1]), out_matmul.shape[-1]) if opt_flags.fused_scatter else out_matmul.view(math.prod(out_matmul.shape[:-2]), *out_matmul.shape[-2:]))
x_storage = _canonicalize_storage(x.storage, 2 if has_gather_tma else 3, flex.lhs_data)
w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
y_storage = _canonicalize_storage(y.storage, 2 if has_scatter_tma else 3, flex.out_data)
# create tma descriptor for x
x_has_tma = opt_flags.is_persistent and (has_gather_tma or not has_gather)
x_tma_block_size = [1, opt_flags.block_k] if has_gather_tma else [1, opt_flags.block_m, opt_flags.block_k]
x_tma_mode = None if not x_has_tma else "ragged" if is_ragged and not has_gather_tma else "dense"
x_tensor_or_tma = x_storage.make_tma(x_tma_block_size, x_tma_mode) if x_has_tma else x_storage.data
# create tma descriptor for y
y_has_tma = opt_flags.is_persistent and (has_scatter_tma or not opt_flags.fused_scatter)
block_n = opt_flags.block_n // opt_flags.epilogue_subtile // matmul_fused_activation.reduction_n
y_tma_block_size = [1, block_n] if has_scatter_tma else [1, opt_flags.block_m, block_n]
y_tma_mode = None if not y_has_tma else "ragged" if is_ragged and not has_scatter_tma else "dense"
y_tensor_or_tma = y_storage.make_tma(y_tma_block_size, y_tma_mode) if y_has_tma else y_storage.data
# create tma descriptor for w
w_has_tma = opt_flags.is_persistent
w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data
# create tma descriptor for w_scale
w_scale_tensor_or_tma = w_scale
w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale
# canonicalize strides
x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides
w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None)
w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides
out_matmul_scale_strides = out_matmul_scale.stride() if out_matmul_has_mx else (None, None, None, None)
out_matmul_scale_strides = (0, ) * (3 - len(out_matmul_scale_strides)) + out_matmul_scale_strides
# launch kernel
kernels = get_kernels(epilogue.specs, matmul_fused_activation.specs)
# When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed
# (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose
    # is True (the fast code path), stride(-2) == 1 takes precedence, e.g., vs.
# w_transpose = w_storage.data.stride()[-1] != 1
w_transpose = w_storage.data.stride()[-2] == 1
(kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
y_tensor_or_tma, y_storage.data, *out_matmul.stride(),
*((None, out_matmul_scale, None) if out_matmul_has_mx else out_matmul_flex),
*out_matmul_scale_strides[-3:],
x_tensor_or_tma, x_storage.data, *x_strides,
flex.lhs_data.scale,
None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
w_tensor_or_tma, w_storage.data, *w_storage.data.stride(), w_transpose,
flex.rhs_data.scale,
w_scale_tensor_or_tma, *w_scale_strides,
bias, bias_stride,
x.shape[-2],
x.shape[-2] if routing_data.expt_hist is None else None,
N, K,
betas, gammas,
None if gather_indx is None else gather_indx.src_indx,
None if scatter_indx is None else scatter_indx.src_indx,
num_indx,
None if not opt_flags.fused_scatter else scatter_indx.dst_indx,
None if not opt_flags.fused_scatter else scatter_indx.dst_indx.shape[0],
expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map,
batch_size, grid_m, grid_n,
out_alpha,
*matmul_fused_activation.fn_args, matmul_fused_activation.reduction_n,
*epilogue.fn_arg_values_matmul,
routing_data.n_expts_tot, routing_data.n_expts_act,
precision_config.max_num_imprecise_acc,
precision_config.allow_tf32,
precision_config.flexpoint_saturate_inf,
flex.rhs_data.is_per_batch,
opt_flags.block_m,
opt_flags.block_n,
opt_flags.block_k,
opt_flags.group_m,
XCD_SWIZZLE=opt_flags.xcd_swizzle,
SWIZZLE_MX_VALUE=w.storage.layout.name,
SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name,
EPILOGUE_SUBTILE=opt_flags.epilogue_subtile,
SPLIT_K=opt_flags.split_k,
EVEN_K=K % opt_flags.block_k == 0,
W_CACHE_MODIFIER=opt_flags.w_cache_modifier,
TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt,
num_warps=opt_flags.num_warps,
num_stages=opt_flags.num_stages,
arch=opt_flags.arch,
UPCAST_INDICES=should_upcast_indices(x, w, out_matmul),
X_TMA_MODE=x_tma_mode,
Y_TMA_MODE=y_tma_mode,
SWAP_XW=get_swap_xw(precision_config, opt_flags),
IS_EPILOGUE_QUANT_MXFP8=epilogue.specs.name == FnName.QUANTIZE_MXFP8.name,
NUM_SMS = grid if opt_flags.is_persistent else 0,
**opt_flags.target_kernel_kwargs)
# Build grouped reduction inputs in a uniform way
group_indx = None if scatter_indx is None or opt_flags.fused_scatter else scatter_indx.src_indx.view(-1, routing_data.n_expts_act)
out_final, out_final_mx_scale = reduce_grouped(
out_matmul,
group_indx,
memory["output"].squeeze(0),
precision_config.out_scale,
reduce_fused_activation,
epilogue,
x_flex=InFlexData(dtype=out_matmul_flex.dtype, scale=out_matmul_flex.expected_scale),
out_flex=precision_config.flex_ctx.out_data,
x_mx_scale=out_matmul_scale.squeeze(1) if out_matmul_has_mx else None,
out_dtype=memory["output"].dtype,
flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf,
)
if not is_input_batched:
out_final = out_final.squeeze(0)
if out_final_mx_scale is not None:
precision_config.out_scale = out_final_mx_scale
return out_final
# -----------------------------------------------------------------------------
# Reference Implementation
# -----------------------------------------------------------------------------
def matmul_ogs_torch(x, w, bias,
routing_data: RoutingData = None,
gather_indx: GatherIndx = None,
scatter_indx: ScatterIndx = None,
precision_config: PrecisionConfig = None,
betas = None,
gammas = None,
round_x = None, round_y = None,
):
is_input_batched = x.ndim == 3
assert x.dtype.itemsize > 1
assert w.dtype.itemsize > 1
if is_input_batched:
assert gather_indx is None, "gather not supported in batched mode"
assert scatter_indx is None, "scatter not supported in batched mode"
assert routing_data is None, "routing not supported in batched mode"
assert w.ndim == 3 and w.shape[0] == x.shape[0]
if round_x is None:
round_x = lambda x, idx: x
if round_y is None:
round_y = lambda x: x
if bias is not None and bias.ndim == 1:
bias = bias.view(1, *bias.shape)
if w.ndim == 2:
w = w.view(1, *w.shape)
if x.ndim == 2:
x = x.view(1, *x.shape)
if routing_data is None:
routing_data = RoutingData(None, None, w.shape[0], 1)
n_expts_act = routing_data.n_expts_act
# memory offsets
if routing_data.n_expts_tot > 1 and not is_input_batched:
sizes = routing_data.expt_hist
off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32)
off[1:] = torch.cumsum(sizes, 0)
offs = list(itertools.pairwise(off))
else:
offs = [[0, x.shape[1]] for _ in range(w.shape[0])]
# compute
n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0]
y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype)
for i, (lo, hi) in enumerate(offs):
if gather_indx is None:
idx = torch.arange(lo, hi, device=x.device)
else:
idx = gather_indx.src_indx[lo:hi] // n_expts_act
batch = i if is_input_batched else 0
        out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device=x.device)).float(),
w[i].float())
if bias is not None:
out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
if gammas is not None:
out *= gammas[lo:hi, None]
y[batch, lo:hi, :] = round_y(out)
if not is_input_batched:
y = y.view(y.shape[1], y.shape[2])
if scatter_indx is None:
return y
# accumulate output from all experts
n_rows = y.shape[0] // n_expts_act
out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device)
for i, (lo, hi) in enumerate(offs):
dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act
msk = dst_idx != -1
out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float()
return out

View File

@ -0,0 +1 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.

View File

@ -0,0 +1,169 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py
# Triton is licensed under the MIT License.
import torch
import triton
import triton.language as tl
# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------
@triton.constexpr_function
def get_scaled_dot_format_string(dtype: tl.dtype):
mapping = {
tl.float16: "fp16",
tl.bfloat16: "bf16",
tl.uint8: "e2m1",
tl.float8e4nv: "e4m3",
tl.float8e5: "e5m2",
}
return mapping[dtype]
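# Example (illustrative): get_scaled_dot_format_string(tl.float8e4nv) returns
# "e4m3", i.e. one of the format strings consumed by tl.dot_scaled in the
# matmul kernels below.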
@triton.jit
def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr):
"""
Swizzle the program id based on integer XCD_SWIZZLE.
    This is useful for reordering how blocks are scheduled. A scheduler may, for example,
    assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10, ... to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2.
This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment
becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to
the same hardware unit.
"""
# Number of pids per group in the new arrangement
pids_per_group = domain_size // XCD_SWIZZLE
extra_pid_groups = domain_size % XCD_SWIZZLE
    # Compute the current group and the local pid within the group
group = pid % XCD_SWIZZLE
local_pid = pid // XCD_SWIZZLE
# Calculate new pid based on the new grouping
new_pid = group * pids_per_group + min(group, extra_pid_groups) + local_pid
return new_pid
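# Worked example (illustrative): with domain_size=8 and XCD_SWIZZLE=4, pids
# 0..7 are remapped to 0, 2, 4, 6, 1, 3, 5, 7. Pids 0 and 4 (which land on the
# same hardware unit under round-robin scheduling) therefore process the
# consecutive blocks 0 and 1.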
@triton.jit
def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr):
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
tl.assume(group_size >= 0)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
return pid_m, pid_n
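# Worked example (illustrative): with grid_m=4, grid_n=2, GROUP_M=2, pids 0..7
# map to (pid_m, pid_n) = (0,0), (1,0), (0,1), (1,1), (2,0), (3,0), (2,1), (3,1),
# i.e. the first four pids cover a 2x2 group of tiles before moving on.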
def make_matmul_repr(base_name, order):
def matmul_repr(specialization):
signature = specialization.signature
constants = specialization.constants
reorder = lambda L: [L[i] for i in order]
layout = lambda stride: "N" if stride in constants else "T"
def convert_dtype(dtype):
if "tensordesc" in dtype:
ret = convert_dtype(dtype.split("<")[1].split("[")[0])
return ret
elif "u8" in dtype:
return "mxfp4"
elif dtype[0] == "*":
return dtype[1:]
else:
return dtype
dtypes = "x".join([convert_dtype(f"{signature[i]}") for i in reorder(["Y", "X", "W"])])
layouts = "".join([f"{layout(i)}" for i in reorder(["stride_y_n", "stride_x_k", "stride_w_n"])])
blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N", "BLOCK_K", "SPLIT_K"]])
# mode = []
# if "GatherIndx" not in constants:
# mode += ['g']
# if "ScatterSrcIndx" not in constants:
# mode += ['s']
# suffix = "" if not mode else "_o" + (''.join(mode))
# if base_name.startswith("_p"):
# suffix += "_ptma"
return f"{base_name}_{layouts}_{dtypes}_{blocks}"
return matmul_repr
def matmul_launch_metadata(grid, kernel, args):
from ..proton_opts import launch_metadata_allow_sync
ret = dict()
M, N, K = args["M"], args["N"], args["K"]
Y, X, W = args["YPtr"], args["XPtr"], args["WPtr"]
tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
hist = args["ExptHist"]
if hist is not None:
# If annotation is given, use that to generate name for profiling.
if tokens_per_expt is not None:
n_rows = f"{tokens_per_expt}*"
elif launch_metadata_allow_sync():
n_rows = int(hist.float().mean())
else:
n_rows = "unknown"
if launch_metadata_allow_sync():
n_tokens = float(hist.sum())
n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
elif tokens_per_expt is not None:
n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
# This may not be totally correct (e.g., we might not be using all experts)
# but it's better than nothing.
n_w_bytes = W.numel() * W.element_size()
else:
n_tokens = None
n_w_bytes = 0
# If annotation is given, use that to generate name for profiling.
tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows
else:
n_tokens = None
n_w_bytes = W.numel() * W.element_size()
repr = lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}"
nbits = X.dtype.itemsize * 8
batch_repr = ""
if "batch_size" in args and args["batch_size"] > 1:
batch_repr = repr("B", args["batch_size"]) + ", "
ret["name"] = f"{kernel.name} [{batch_repr}{repr('M', M)}, {repr('N', N)}, {repr('K', K)}] stg{kernel.num_stages}"
ep_subtile = args["EPILOGUE_SUBTILE"]
if ep_subtile is not None and ep_subtile > 1:
ret["name"] += f" ep/{ep_subtile}"
if hist is not None and n_tokens is None:
return ret # Don't fill metadata because we can't compute them properly.
fM = M if M is not None else n_tokens
fK = K if K is not None else n_tokens
ret[f"flops{nbits}"] = 2.0 * fM * N * fK
gindx = args.get("GatherIndx", None)
# sindx = args.get("WriteBackIndx", None)
n_x_bytes = X.numel() * X.element_size()
n_y_bytes = Y.numel() * Y.element_size()
if hist is not None:
assert n_tokens is not None
n_expts_act = args["N_EXPTS_ACT"]
if (gindx is not None) and launch_metadata_allow_sync():
# recreate inverse GatherIndx.
dst = torch.full_like(gindx, -1)
idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
mask = (gindx != -1)
dst[gindx[mask]] = idx[mask]
n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum()
else:
n_read_rows = n_tokens
n_x_bytes = n_read_rows * X.shape[-1] * X.element_size()
n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size()
ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes)
return ret
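# Illustrative check of the accounting above: for a dense bf16 problem with
# M = N = K = 4096, the metadata records flops16 = 2 * 4096**3 ~= 1.37e11 for
# the launch.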

View File

@ -0,0 +1,433 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py
# Triton is licensed under the MIT License.
# isort: off
# fmt: off
import triton
import triton.language as tl
from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
from triton_kernels.tensor_details.layout_details.cdna4_scale import unswizzle_mx_scale_cdna4
from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
@triton.jit
def _zero_masked_rows(
pid_m, pid_n,
Y, stride_y_m, stride_y_n,
N,
ScatterSrcIndx, num_idxs,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
offs_m = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M)
offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
YPtrs = Y + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
mask_n = offs_n < N
mask = (src_idx == -1)[:, None] & mask_n[None, :]
tl.store(YPtrs, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask)
_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2])
@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
def _matmul_ogs(
Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
YExpectedScale, YActualScale, YChecksumScale,
stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
X, XPtr, stride_x_z, stride_x_m, stride_x_k,
XScale,
XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
WScale,
WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n,
B, stride_b_e, # Bias
NRows, M, N, K, # shapes
# expt data
Betas, Gammas,
GatherIndx,
ScatterSrcIndx, num_idxs,
WriteBackIndx, writeback_size,
ExptHist, ExptOffs, ExptOffsSum, ExptData,
# true grid size
batch_size, grid_m, grid_n,
# Out scale
out_alpha,
# fused activation function
ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
# epilogue transform
EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
# MoE config
N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
# precision config
MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
FLEXPOINT_SATURATE_INF: tl.constexpr,
PER_BATCH_SCALE: tl.constexpr,
# optimization config
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
# One of ["HOPPER", "BLACKWELL", None]
SWIZZLE_MX_VALUE: tl.constexpr,
# One of ["HOPPER", "BLACKWELL", None]
SWIZZLE_MX_SCALE: tl.constexpr,
EPILOGUE_SUBTILE: tl.constexpr,
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
W_CACHE_MODIFIER: tl.constexpr,
NUM_SMS: tl.constexpr,
X_TMA_MODE: tl.constexpr,
Y_TMA_MODE: tl.constexpr,
TOKENS_PER_EXPT_FOR_ANNOTATION=None,
UPCAST_INDICES: tl.constexpr = False,
SWAP_XW: tl.constexpr = False,
IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
tl.assume(stride_y_k >= 0)
tl.assume(stride_y_z >= 0)
tl.assume(stride_y_m >= 0)
tl.assume(stride_y_n >= 0)
tl.assume(stride_x_z >= 0)
tl.assume(stride_x_m >= 0)
tl.assume(stride_x_k >= 0)
tl.assume(stride_w_e >= 0)
tl.assume(stride_w_k >= 0)
tl.assume(stride_w_n >= 0)
if stride_w_mx_e is not None:
tl.assume(stride_w_mx_e >= 0)
if stride_w_mx_k is not None:
tl.assume(stride_w_mx_k >= 0)
if stride_w_mx_n is not None:
tl.assume(stride_w_mx_n >= 0)
if B is not None:
tl.assume(stride_b_e >= 0)
tl.assume(batch_size >= 0)
tl.assume(grid_m >= 0)
tl.assume(grid_n >= 0)
is_w_microscaled: tl.constexpr = WMxScale is not None
MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
if is_w_microscaled:
w_type: tl.constexpr = W.dtype.element_ty
is_mxfp4: tl.constexpr = w_type == tl.uint8
tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
"mx_weight_ptr must be uint8 or fp8")
tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values")
else:
tl.static_assert(SWIZZLE_MX_VALUE is None)
tl.static_assert(SWIZZLE_MX_SCALE is None)
is_x_microscaled: tl.constexpr = XMxScale is not None
if is_x_microscaled:
x_type: tl.constexpr = X.dtype.element_ty
tl.static_assert(is_w_microscaled)
tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
tl.static_assert(XMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
is_out_microscaled: tl.constexpr = stride_y_mx_z is not None
OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
yN = N // ACTIVATION_REDUCTION_N
pid = tl.program_id(0)
if ExptOffsSum is not None and XCD_SWIZZLE > 1:
# Determine how much padding there is on the expert data. This allows us to
# know the true grid size and avoid processing padding tiles.
padding_m = grid_m - tl.load(ExptOffsSum)
else:
padding_m: tl.constexpr = 0
HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32
unpadded_m = grid_m - padding_m
tl.assume(unpadded_m >= 0)
total_actual_tiles = batch_size * unpadded_m * grid_n * SPLIT_K
if padding_m > 0 and pid >= total_actual_tiles:
tl.device_assert(batch_size == 0)
pid_mn = pid - total_actual_tiles
if pid_mn < padding_m * grid_n:
pid_m, pid_n = swizzle2d(pid_mn, padding_m, grid_n, GROUP_M)
# set masked out rows to 0
if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
_zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
return
# swizzle program ids
pid_emnk = pid
if XCD_SWIZZLE != 1:
pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE)
pid_e = pid_emnk // (unpadded_m * grid_n * SPLIT_K)
pid_mnk = pid_emnk % (unpadded_m * grid_n * SPLIT_K)
pid_k = pid_mnk % SPLIT_K
pid_mn = pid_mnk // SPLIT_K
pid_m, pid_n = swizzle2d(pid_mn, unpadded_m, grid_n, GROUP_M)
# For split-k, advance to the output k slice
if SPLIT_K > 1:
        Y += pid_k.to(index_type) * stride_y_k
if is_out_microscaled:
YActualScale += pid_k.to(index_type) * stride_x_mx_k
# set masked out rows to 0
if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
_zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
# unpack expert data
if ExptData is None:
tl.static_assert(M is not None)
expt_id, start_z, start_m, block_id = pid_e, pid_e, 0, pid_m
else:
tl.static_assert(M is None)
expt_data = tl.load(ExptData + pid_m)
if expt_data == -1:
return
expt_id = expt_data & 0x0000FFFF
block_id = expt_data >> 16
M = tl.load(ExptHist + expt_id)
start_m = tl.load(ExptOffs + expt_id)
start_z = 0
expt_id, block_id = expt_id.to(index_type), block_id.to(index_type)
start_m, start_z = start_m.to(index_type), start_z.to(index_type)
pid_n, pid_k = pid_n.to(index_type), pid_k.to(index_type)
# A pointers
offs_x_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
offs_x_m = tl.max_contiguous(tl.multiple_of(offs_x_m % M, BLOCK_M), BLOCK_M)
X += start_z * stride_x_z
if GatherIndx is None:
X += start_m * stride_x_m
else:
GatherIndx += start_m
        # no need to bounds-check here because `offs_x_m` wraps around the M dim
offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT
offs_k = BLOCK_K * pid_k + tl.arange(0, BLOCK_K)
XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k
# TODO: refactor if/else when triton front end improves
if is_w_microscaled:
if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
tl.static_assert(is_mxfp4, "Only mxfp4 is supported for HOPPER swizzling")
tl.static_assert(not is_x_microscaled)
            # We pack 2 fp4 values in a byte, but we divide the dimension by 2
# when swizzling
W_K_DIVISOR: tl.constexpr = 1
W_K_MULTIPLIER: tl.constexpr = 2
W_N_DIVISOR: tl.constexpr = 4
else:
            # We pack 2 fp4 values in a byte
W_K_DIVISOR: tl.constexpr = 2 if is_mxfp4 else 1
W_K_MULTIPLIER: tl.constexpr = 1
W_N_DIVISOR: tl.constexpr = 1
if W_TRANSPOSE:
# When weight is transposed, 2 fp4 values are packed per Byte along
# the contiguous dimension, K.
PACKED_BLOCK_K_W: tl.constexpr = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER
PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR
else:
# When weight is not transposed, fp4 values are *not* packed along
# the contiguous dimension, N.
PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_K_DIVISOR
MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
WMxScale += expt_id * stride_w_mx_e
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
# TODO: support non W_TRANSPOSE with blackwell swizzling
tl.static_assert(W_TRANSPOSE)
tl.static_assert(BLOCK_N % 128 == 0)
tl.static_assert(MX_SCALE_BLOCK_K % 4 == 0)
PACKED_MX_BLOCK: tl.constexpr = (MX_SCALE_BLOCK_K // 4) * 32 * 4 * 4
SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 128
stride_scale_k: tl.constexpr = 1
elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
# TODO: support non W_TRANSPOSE with Hopper swizzling
tl.static_assert(W_TRANSPOSE)
n_warps: tl.constexpr = tl.extra.cuda.num_warps()
tl.static_assert(BLOCK_N % (2 * n_warps * 2 * 8) == 0)
tl.static_assert(MX_SCALE_BLOCK_K % 2 == 0)
PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32
SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 32
stride_scale_k = stride_w_mx_k
elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
tl.static_assert(stride_w_mx_k is not None)
tl.static_assert(stride_w_mx_n is not None)
NON_K_PRESHUFFLE_BLOCK_SIZE: tl.constexpr = 32
PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * NON_K_PRESHUFFLE_BLOCK_SIZE
SCALE_BLOCK_N: tl.constexpr = BLOCK_N // NON_K_PRESHUFFLE_BLOCK_SIZE
stride_scale_k = stride_w_mx_k
else:
PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K
SCALE_BLOCK_N: tl.constexpr = BLOCK_N
stride_scale_k = stride_w_mx_k
offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N
offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N)
# K dimension must be the last dimension for the scales
offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK)
WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n
else:
WMxScalePtrs = None
offs_k_scale = None
W_K_DIVISOR: tl.constexpr = 1
W_K_MULTIPLIER: tl.constexpr = 1
W_N_DIVISOR: tl.constexpr = 1
PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N
# B pointers
offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W)
offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)
if is_x_microscaled:
XMxScale += start_z.to(index_type) * stride_x_mx_z
if GatherIndx is None:
XMxScale += start_m * stride_x_mx_m
offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k
else:
XMxScalePtrs = None
offs_w_k = PACKED_BLOCK_K_W * pid_k + tl.arange(0, PACKED_BLOCK_K_W)
W += expt_id * stride_w_e
WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n)
# compute output
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, BLOCK_K * pid_k, -(BLOCK_K * SPLIT_K)):
if EVEN_K:
mask_k = tl.full([BLOCK_K], True, dtype=tl.int1)
mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1)
if is_w_microscaled and SWIZZLE_MX_SCALE is None:
mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1)
if is_x_microscaled:
mask_x_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
else:
mask_k = offs_k < k
mask_k_w = offs_w_k < ((k // (W_K_DIVISOR if W_TRANSPOSE else 1)) * W_K_MULTIPLIER)
if is_w_microscaled and SWIZZLE_MX_SCALE is None:
mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < k
if is_x_microscaled:
mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < k
x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER)
if is_w_microscaled:
x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
if is_x_microscaled:
x_scales = tl.load(XMxScalePtrs, mask=mask_x_k_scale[None, :])
elif x_format == "fp16" or x_format == "bf16":
x_scales: tl.constexpr = None
else:
# Scale of 1 in E8M0 format
x_scales = tl.full((BLOCK_M, MX_SCALE_BLOCK_K), 127, dtype=tl.uint8)
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
w_scales = unswizzle_mx_scale_bw(tl.load(WMxScalePtrs))
elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
# Handshake with the swizzling code
num_warps: tl.constexpr = tl.extra.cuda.num_warps()
w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps)
elif SWIZZLE_MX_SCALE == "CDNA4_SCALE":
w_scales = unswizzle_mx_scale_cdna4(tl.load(WMxScalePtrs), BLOCK_N, MX_SCALE_BLOCK_K)
else:
w_scales = tl.load(WMxScalePtrs, mask=mask_k_scale[None, :])
if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
# Handshake with the swizzling code
tl.static_assert(x_format == "bf16")
tl.static_assert(w_format == "e2m1")
w = mxfp4_to_bf16_triton(w.trans(), w_scales, 1)
tl.static_assert(w.dtype == tl.bfloat16)
acc = acc.trans()
x = x.trans()
# w = w.trans()
acc = tl.dot(w, x, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
acc = acc.trans()
else:
rhs_k_pack: tl.constexpr = W_TRANSPOSE or not is_w_microscaled or W_K_DIVISOR != 2
acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True, rhs_k_pack=rhs_k_pack)
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
WMxScalePtrs += (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_w_mx_k
else:
WMxScalePtrs += (PACKED_MX_BLOCK * SPLIT_K) * stride_w_mx_k
if is_x_microscaled:
XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k
else:
acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k
WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k
# bias + scale
offs_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
offs_y_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
mask_m = offs_m < M
mask_n = offs_y_n < N
if B is not None:
BPtrs = B + expt_id * stride_b_e + offs_y_n
if pid_k == 0:
bias = tl.load(BPtrs, mask=mask_n, other=0)
else:
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
else:
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
if Betas is not None:
betas = tl.load(Betas + start_m + offs_m, mask=mask_m, other=0.0)
else:
betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
if Gammas is not None:
gammas = tl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0)
else:
gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
# flexpoint
x_scale = load_scale(XScale)
if PER_BATCH_SCALE:
w_scale = load_scale(WScale + expt_id)
else:
w_scale = load_scale(WScale)
acc *= x_scale * w_scale
acc = acc + bias[None, :] * betas[:, None]
if out_alpha is not None:
acc *= out_alpha
if ACTIVATION_FN is not None:
out = ACTIVATION_FN(acc, *activation_fn_args)
tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
offs_y_n = OUT_BLOCK_N * pid_n + tl.arange(0, OUT_BLOCK_N)
mask_n = offs_y_n < yN
else:
tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
out = acc
out *= gammas[:, None]
# write-back
Y += start_z.to(index_type) * stride_y_z
if WriteBackIndx is not None:
WriteBackIndx += start_m
dst_idx = tl.load(WriteBackIndx + offs_m, mask=start_m + offs_m < writeback_size, other=-1)
mask_m = mask_m & (dst_idx != -1)
offs_y_m = dst_idx
else:
Y += start_m * stride_y_m
offs_y_m = offs_m
YPtrs = Y + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n
mask = mask_m[:, None] & mask_n[None, :]
if is_out_microscaled:
MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE
N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
tl.static_assert(EPILOGUE_FN is not None)
out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args)
tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "")
offs_y_n_scale = MX_SCALE_BLOCK_N * pid_n + tl.arange(0, MX_SCALE_BLOCK_N)
mask_n_scale = offs_y_n_scale < N_MX_BLOCK
YActualScale += start_z.to(index_type) * stride_y_mx_z
if WriteBackIndx is None:
YActualScale += start_m * stride_y_mx_m
YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
else:
YActualScalePtrs = YActualScale + (offs_y_m - NRows).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
else:
out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF)
if EPILOGUE_FN is not None and not IS_EPILOGUE_QUANT_MXFP8:
out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty)
tl.store(YPtrs, out, mask=mask)

View File

@ -0,0 +1,475 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
# Triton is licensed under the MIT License.
# isort: off
# fmt: off
import torch
import triton
import triton.language as tl
from triton.tools.ragged_tma import load_ragged, store_ragged
from triton_kernels import target_info
from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
from triton_kernels.numerics_details.flexpoint import (
float_to_flex,
load_scale,
nan_propagating_absmax_reduce,
compute_scale,
)
from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
@triton.constexpr_function
def cuda_capability_geq(major, minor):
return target_info.cuda_capability_geq(major, minor)
@triton.constexpr_function
def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
if isinstance(tensor_or_desc, tl.tensor):
return tensor_or_desc.dtype.element_ty
elif isinstance(tensor_or_desc, tl.tensor_descriptor):
return tensor_or_desc.dtype
else:
raise ValueError(f"Invalid type: {type(tensor_or_desc)}")
@triton.jit
def _load_tile_attrs(
tile_id, num_tiles, grid_m, grid_n, padding_m,
M, ExptData, ExptHist, ExptOffs,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, SPLIT_K: tl.constexpr,
GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr):
# unpack and swizzle program ids
pid_emnk = tile_id
if XCD_SWIZZLE != 1:
pid_emnk = xcd_swizzle(pid_emnk, num_tiles // SPLIT_K, XCD_SWIZZLE)
pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K)
pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K)
if SPLIT_K > 1:
pid_k = pid_mnk % SPLIT_K
pid_mn = pid_mnk // SPLIT_K
else:
pid_k: tl.constexpr = 0
pid_mn = pid_mnk
pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M)
# unpack expert data
if ExptData is None:
tl.static_assert(M is not None)
expt_id, start_z, start_m, block_id, eM = pid_e, pid_e, 0, pid_m, -1
else:
tl.static_assert(M is None)
expt_data = tl.load(ExptData + pid_m)
expt_id = expt_data & 0x0000FFFF
block_id = expt_data >> 16
eM = tl.load(ExptHist + expt_id)
start_m = tl.load(ExptOffs + expt_id)
start_z = 0
off_m = BLOCK_M * block_id
off_n = BLOCK_N * pid_n
return expt_id, start_z, start_m, eM, off_m, off_n, pid_k
@triton.jit
def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
mask = mask & (offs < writeback_size)
offs = tl.load(WriteBackIndx + offs, mask=mask, other=-1)
mask = offs != -1
return (offs, mask)
_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2])
@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
def _p_matmul_ogs(
Y, YPtr, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
YExpectedScale, YActualScale, YChecksumScale,
stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
X, XPtr, stride_x_z, stride_x_m, stride_x_k,
XScale,
XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
W, WPtr, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
WScale,
MxScale, stride_mx_e, stride_mx_k, stride_mx_n,
B, stride_b_e, # Bias
NRows, M, N, K, # shapes
# expt data
Betas, Gammas,
GatherIndx,
ScatterSrcIndx, num_idxs,
WriteBackIndx, writeback_size,
ExptHist, ExptOffs, ExptOffsSum, ExptData,
# true grid size
batch_size, grid_m, grid_n,
# Out scale
out_alpha,
# fused activation function
ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
# epilogue transform
EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
# MoE config
N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
# precision config
MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
FLEXPOINT_SATURATE_INF: tl.constexpr,
PER_BATCH_SCALE: tl.constexpr,
# optimization config
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
# NYI: Must be None
SWIZZLE_MX_VALUE: tl.constexpr,
# One of ["BLACKWELL", None]
SWIZZLE_MX_SCALE: tl.constexpr,
EPILOGUE_SUBTILE: tl.constexpr,
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
W_CACHE_MODIFIER: tl.constexpr,
NUM_SMS: tl.constexpr,
X_TMA_MODE: tl.constexpr,
Y_TMA_MODE: tl.constexpr,
TOKENS_PER_EXPT_FOR_ANNOTATION=None,
UPCAST_INDICES:tl.constexpr=False,
SWAP_XW: tl.constexpr = False,
IS_EPILOGUE_QUANT_MXFP8: tl.constexpr = False):
# tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling")
# why is this faster than using host-side tensor descriptor?!
if Y_TMA_MODE is not None:
Y = tl.make_tensor_descriptor(YPtr, Y.shape, Y.strides[:-1] + (1,), Y.block_shape)
is_microscaled_format: tl.constexpr = MxScale is not None
tl.static_assert(not is_microscaled_format or W_TRANSPOSE, "NYI. Non-transposed mxfp4 weights")
MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
if is_microscaled_format:
w_type: tl.constexpr = get_dtype(W)
        tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
                         "mx_weight_ptr must be uint8 or fp8")
tl.static_assert(get_dtype(MxScale) == tl.uint8, "mx_scale_ptr must be uint8")
tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")
        # We pack 2 fp4 values in a byte
W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1
PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR
MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
else:
W_PACK_DIVISOR: tl.constexpr = 1
MX_SCALE_BLOCK_K: tl.constexpr = 1
PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
tl.static_assert(SWIZZLE_MX_SCALE is None)
if ExptOffsSum is not None:
# Determine how much padding there is on the expert data. This allows us to
# know the true grid size and avoid processing padding tiles.
padding_m = grid_m - tl.load(ExptOffsSum)
else:
padding_m: tl.constexpr = 0
index_type: tl.constexpr = tl.int64
USE_FLEXPOINT_SCALE: tl.constexpr = YActualScale is not None or YChecksumScale is not None
HAS_SCATTER: tl.constexpr = WriteBackIndx is not None
HAS_GATHER: tl.constexpr = GatherIndx is not None
USE_GATHER_TMA: tl.constexpr = HAS_GATHER and X_TMA_MODE == "dense"
USE_SCATTER_TMA: tl.constexpr = HAS_SCATTER and Y_TMA_MODE == "dense"
if EPILOGUE_SUBTILE is None:
SUBTILE_FACTOR: tl.constexpr = 1
else:
SUBTILE_FACTOR: tl.constexpr = EPILOGUE_SUBTILE
EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // SUBTILE_FACTOR
OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N
yN = N // ACTIVATION_REDUCTION_N
# set masked out rows to 0
if HAS_SCATTER and N_EXPTS_ACT == 1:
# Iterate with reversed pids so that later pids will get more tiles if the number of
# tiles isn't evenly divisible by the number of SMs.
# The main loop after this iterates in the forward direction such that earlier
# pids get more tiles if the number of tiles isn't evenly divisible.
# This helps balance the work across the SMs.
for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS):
pid_k = pid_mnk % SPLIT_K
pid_mn = pid_mnk // SPLIT_K
pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M)
z = tl.zeros([BLOCK_M, BLOCK_N // ACTIVATION_REDUCTION_N], dtype=tl.float32)
offs_m = z.shape[0] * pid_m + tl.arange(0, z.shape[0])
offs_n = z.shape[1] * pid_n + tl.arange(0, z.shape[1])
src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
YPtrs = YPtr + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
mask_n = offs_n < yN
mask = (src_idx == -1)[:, None] & mask_n[None, :]
tl.store(YPtrs + pid_k * stride_y_k, z, mask=mask)
k_tiles = tl.cdiv(K, BLOCK_K * SPLIT_K)
num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K
# If true, do not share loop-carried variables between the prologue and the
# epilogue to enable better pipelining with mmav5
INDEPENDENT_EPILOGUE: tl.constexpr = cuda_capability_geq(10, 0)
# start negative; will be incremented at the top of the loop
if INDEPENDENT_EPILOGUE:
tile_id1 = tl.program_id(0) - NUM_SMS
# Keep track of local max for updating flexpoint scales.
THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32)
DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256
for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=True):
expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs(
tile_id, num_tiles, grid_m, grid_n, padding_m,
M, ExptData, ExptHist, ExptOffs,
BLOCK_M, BLOCK_N, SPLIT_K,
GROUP_M, XCD_SWIZZLE)
# Base pointers and offsets.
if X_TMA_MODE is None:
XBase = X + start_z.to(index_type) * stride_x_z
offs_x_k = tl.arange(0, BLOCK_K)[None, :] * stride_x_k
if SPLIT_K > 1:
offs_x_k += pid_k.to(index_type) * BLOCK_K * stride_x_k
if USE_GATHER_TMA:
offs_m = off_m + tl.arange(0, BLOCK_M)
mask_m = offs_m < (M if M is not None else eM)
if ExptData is None:
offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m)
# Bump rows to account for the Z offset.
offs_x_m += start_z * (stride_x_z // stride_x_m)
offs_x_m = tl.where(mask_m, offs_x_m, -1)
else:
offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m,
mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT
elif X_TMA_MODE is None:
tl.static_assert(HAS_GATHER)
offs_m = off_m + tl.arange(0, BLOCK_M)
if M is not None:
offs_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M)
else:
offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M)
            # no need to bounds-check here because `offs_m` wraps around the M dim
offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT
offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m
acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32)
for ki in tl.range(k_tiles, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER):
off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K
off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K
off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K
# --- load x ---
if USE_GATHER_TMA:
x = X.gather(offs_x_m, off_k)
elif X_TMA_MODE == "dense":
x = X.load([start_z, start_m + off_m, off_k])
x = x.reshape(BLOCK_M, BLOCK_K)
elif X_TMA_MODE == "ragged":
x = load_ragged(X, start_m, eM, [start_z, off_m, off_k], ragged_dim=1)
x = x.reshape(BLOCK_M, BLOCK_K)
else:
tl.static_assert(X_TMA_MODE is None)
XPtrs = XBase + offs_x_m + offs_x_k
XBase += BLOCK_K * SPLIT_K * stride_x_k
mask_k = tl.arange(0, BLOCK_K) < K - off_k
if EVEN_K:
if SPLIT_K > 1:
x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
else:
x = tl.load(XPtrs)
else:
x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
# --- load w ---
if W_TRANSPOSE:
w = tl.reshape(W.load([expt_id, off_n, off_k_w]), W.block_shape[1:]).T
else:
w = tl.reshape(W.load([expt_id, off_k_w, off_n]), W.block_shape[1:])
# --- load w_scale ---
if is_microscaled_format:
x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
if x_format == "fp16" or x_format == "bf16":
x_scales: tl.constexpr = None
else:
x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8)
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128)
w_scales = MxScale.load([0, flattened_expt_n_idx, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
w_scales = unswizzle_mx_scale_bw(w_scales)
else:
w_scales = MxScale.load([expt_id, off_k_mx, off_n])
w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T
# --- update accumulator ---
if is_microscaled_format:
if SWAP_XW:
acc = tl.dot_scaled(w.T, w_scales, mx_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
else:
acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True)
else:
if SWAP_XW:
acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
else:
acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
if INDEPENDENT_EPILOGUE:
tile_id1 += NUM_SMS
expt_id1, start_z1, start_m1, eM1, off_m1, off_n1, pid_k1 = _load_tile_attrs(
tile_id1, num_tiles, grid_m, grid_n, padding_m,
M, ExptData, ExptHist, ExptOffs,
BLOCK_M, BLOCK_N, SPLIT_K,
GROUP_M, XCD_SWIZZLE)
else:
tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z, start_m, eM
off_m1, off_n1, pid_k1 = off_m, off_n, pid_k
offs_m = off_m1 + tl.arange(0, BLOCK_M)
mask_m = offs_m < (M if M is not None else eM1)
if USE_SCATTER_TMA:
offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
if SPLIT_K > 1:
# Compute the split k offset in number of rows, and add it to offs_y_m.
# This allows us to write to the correct slice in the output tensor while using
# a 2D TMA scatter.
tl.device_assert(stride_y_k // stride_y_m == tl.cdiv(stride_y_k, stride_y_m))
split_k_row_offs = pid_k1 * (stride_y_k // stride_y_m)
offs_y_m = tl.where(mask_m, offs_y_m + split_k_row_offs, offs_y_m)
elif Y_TMA_MODE is None:
tl.static_assert(HAS_SCATTER)
offs_y_m, mask_m = _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
else:
offs_y_m = start_m1 + offs_m
MASK_ACC = False if USE_GATHER_TMA else USE_FLEXPOINT_SCALE
# bias + scale
offs_y_n = off_n1 + tl.arange(0, BLOCK_N)
mask_n = offs_y_n < N
if B is not None:
BPtrs = B + expt_id1 * stride_b_e + offs_y_n
if pid_k1 == 0:
bias = tl.load(BPtrs, mask=mask_n, other=0)
else:
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
else:
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
if Betas is not None:
betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0)
else:
betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
if Gammas is not None:
gammas = tl.load(Gammas + start_m1 + offs_m, mask=mask_m, other=0.0)
else:
gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
x_scale = load_scale(XScale)
if PER_BATCH_SCALE:
w_scale = load_scale(WScale + expt_id1)
else:
w_scale = load_scale(WScale)
accs = (acc,)
biases = (bias,)
if SUBTILE_FACTOR >= 2:
acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
accs = (acc0, acc1)
bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
biases = (bias0, bias1)
if SUBTILE_FACTOR >= 4:
acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
accs = (acc00, acc01, acc10, acc11)
bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
biases = (bias00, bias01, bias10, bias11)
tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
tl.static_assert(len(accs) == SUBTILE_FACTOR)
for a_i in tl.static_range(len(accs)):
acc_tile = accs[a_i]
acc_tile *= x_scale * w_scale
if SWAP_XW:
acc_tile = acc_tile.T
acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
if out_alpha is not None:
acc_tile *= out_alpha
if ACTIVATION_FN is not None:
out = ACTIVATION_FN(acc_tile, *activation_fn_args)
tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
else:
tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
out = acc_tile
out *= gammas[:, None]
if MASK_ACC:
out = tl.where(mask_m[:, None], out, 0.0)
# Flexpoint
out_view = tl.reshape(out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True)
local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0))
out = float_to_flex(
out, YExpectedScale,
None, # ActualScale: local absmax is tracked and updated after the loop
YChecksumScale,
None, # mask: out is manually masked to 0
YPtr, FLEXPOINT_SATURATE_INF
)
if EPILOGUE_FN is not None:
out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtr.dtype.element_ty, pid=len(accs)*tile_id1 + a_i)
out_off_n = off_n1 // ACTIVATION_REDUCTION_N + a_i * OUT_BLOCK_N
out = out.to(YPtr.dtype.element_ty)
if USE_SCATTER_TMA:
# Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
# there shouldn't be any other negative values.
offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True)
Y.scatter(out, offs_y_m, out_off_n)
elif Y_TMA_MODE == "dense":
out = tl.reshape(out, [1] + out.shape)
off_kz = pid_k * batch_size + start_z1
Y.store([off_kz, off_m1, out_off_n], out)
elif Y_TMA_MODE == "ragged":
out = tl.reshape(out, [1] + out.shape)
store_ragged(Y, start_m1, eM1, [pid_k, off_m1, out_off_n], out, ragged_dim=1)
else:
tl.static_assert(Y_TMA_MODE is None)
offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N)
mask_n = offs_y_n < yN
YPtrs = YPtr + pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n
mask = mask_m[:, None] & mask_n[None, :]
tl.store(YPtrs, out, mask=mask)
# Update the flexpoint scales
if YActualScale is not None:
tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), YPtr), sem="relaxed")
_per_device_alloc_fns = {}
def get_per_device_per_stream_alloc_fn(device):
if device not in _per_device_alloc_fns:
_per_stream_tensors = {}
def alloc_fn(size: int, alignment: int, stream):
assert alignment == 128
if stream not in _per_stream_tensors or _per_stream_tensors[stream].numel() < size:
_per_stream_tensors[stream] = torch.empty(size, device=device, dtype=torch.int8)
_per_stream_tensors[stream].__hibernate__ = {"type": "ignore"}
return _per_stream_tensors[stream]
_per_device_alloc_fns[device] = alloc_fn
return _per_device_alloc_fns[device]
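The helper above caches one grow-only int8 scratch buffer per (device, stream) pair so that repeated TMA-enabled launches reuse memory instead of reallocating. The snippet below is an illustrative sketch only (not part of the vendored file); it assumes a Triton build that exposes `triton.set_allocator` with an `alloc_fn(size, alignment, stream)` hook, as used in Triton's persistent/TMA tutorials, so verify against your Triton version before relying on it.
# Illustrative sketch, not part of the vendored file. Assumes triton.set_allocator
# exists and accepts an alloc_fn(size, alignment, stream) callable returning a tensor.
import torch
import triton
_scratch = {}  # stream -> grow-only int8 scratch buffer, mirroring the helper above
def _alloc_fn(size: int, alignment: int, stream):
    if stream not in _scratch or _scratch[stream].numel() < size:
        _scratch[stream] = torch.empty(size, device="cuda", dtype=torch.int8)
    return _scratch[stream]
triton.set_allocator(_alloc_fn)  # the vendored get_per_device_per_stream_alloc_fn(device) produces an alloc_fn of the same shape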

View File

@ -0,0 +1,98 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py
# Triton is licensed under the MIT License.
from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
from triton_kernels.numerics_details.mxfp import quantize_mxfp8_fn
import triton
import triton.language as tl
@triton.jit
def _reduce_grouped(X, stride_xb: tl.uint64, stride_xm: tl.uint64, stride_xn, #
XScale, # input scalar flex scale
Out, stride_om: tl.uint64, stride_on, # output tensor
OutExpectedScale, OutActualScale, OutChecksumScale, # output scalar flex scales
InIndx, B, N, #
XMxScale, stride_mxb: tl.uint64,
stride_mxs: tl.uint64, # optional per-32-col output MXFP scales (uint8)
OutMxScale, stride_omxs: tl.uint64, # optional per-32-col output MXFP scales (uint8)
# fused activation function
ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
# epilogue transform
EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
#
HAS_IN_MX_SCALE: tl.constexpr, HAS_OUT_MX_SCALE: tl.constexpr, FLEXPOINT_SATURATE_INF: tl.constexpr,
K: tl.constexpr, BLOCK_N: tl.constexpr):
pid_t = tl.program_id(0)
BLOCK_N_OUT: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
# persistent along N: single program on N, iterate tiles of size BLOCK_N
start = pid_t * K
# load indices into a tuple
if InIndx is None:
indxs = (pid_t, )
else:
indxs = ()
for i in tl.static_range(0, K):
indxs = indxs + (tl.load(InIndx + start + i), )
# determine first valid topk row
fi = indxs[(K - 1)]
for i in tl.static_range(K - 2, -1, -1):
fi = tl.where(indxs[i] != -1, indxs[i], fi)
# record overwritten row index (may be -1 if none)
XPtrs = X + tl.arange(0, BLOCK_N) * stride_xn
OutPtrs = Out + tl.arange(0, BLOCK_N_OUT) * stride_on
if HAS_IN_MX_SCALE:
XScalePtrs = XMxScale + tl.arange(0, BLOCK_N // 32) * stride_xn
if HAS_OUT_MX_SCALE:
OutScalePtrs = OutMxScale + tl.arange(0, BLOCK_N_OUT // 32) * stride_on
x_scale = load_scale(XScale)
for n_curr in tl.range(0, N, BLOCK_N, num_stages=4):
acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32)
x_n_mask = tl.arange(0, BLOCK_N) < N - n_curr
x_n_mask_scale = tl.arange(0, BLOCK_N // 32) < tl.cdiv(N - n_curr, 32)
# accumulate contributions for this tile
for i in tl.static_range(0, K):
curr = tl.zeros([BLOCK_N], dtype=tl.float32)
# iterate over split_k partial values
for b in tl.range(0, B):
is_valid = indxs[i] != -1
x_row_ptr = XPtrs + indxs[i] * stride_xm + b * stride_xb
vals = tl.load(x_row_ptr, mask=x_n_mask & is_valid, other=0.0)
vals = vals.to(tl.float32)
if HAS_IN_MX_SCALE:
scale_row_ptr = XScalePtrs + indxs[i] * stride_mxs + b * stride_mxb
scale = tl.load(scale_row_ptr, mask=x_n_mask_scale & is_valid, other=0.)
scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
vals = vals.reshape([BLOCK_N // 32, 32])
vals = (scale[:, None] * vals).reshape([BLOCK_N])
curr += vals
# apply nonlinearity to split-k output
if ACTIVATION_FN is not None:
curr = ACTIVATION_FN(curr[None, :], *activation_fn_args)
curr = tl.reshape(curr, [curr.shape[-1]])
# update final accumulator
acc += curr
acc *= x_scale
# Compute per-32-col MXFP scales for this tile if requested
Nrem = (N - n_curr) // ACTIVATION_REDUCTION_N
out_n_mask = tl.arange(0, BLOCK_N_OUT) < Nrem
out_n_mask_scale = tl.arange(0, BLOCK_N_OUT // 32) < tl.cdiv(Nrem, 32)
if HAS_OUT_MX_SCALE:
acc, acc_scale = quantize_mxfp8_fn(acc[None, :], out_n_mask[None, :])
acc = tl.reshape(acc, [acc.shape[-1]])
acc_scale = tl.reshape(acc_scale, [acc_scale.shape[-1]])
# Convert to flexpoint output if configured (scalar scales)
acc = float_to_flex(acc, OutExpectedScale, OutActualScale, OutChecksumScale, None, Out, FLEXPOINT_SATURATE_INF)
# write-back for this tile
out_ptr = OutPtrs + pid_t * stride_om
tl.store(out_ptr, acc, mask=out_n_mask)
if HAS_OUT_MX_SCALE:
out_scale_ptr = OutScalePtrs + pid_t * stride_omxs
tl.store(out_scale_ptr, acc_scale, mask=out_n_mask_scale)
XPtrs += BLOCK_N * stride_xn
OutPtrs += BLOCK_N_OUT * stride_on
if HAS_IN_MX_SCALE:
XScalePtrs += BLOCK_N // 32 * stride_xn
if HAS_OUT_MX_SCALE:
OutScalePtrs += BLOCK_N_OUT // 32 * stride_xn
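The per-32-column MX scales loaded above are e8m0 biased exponents: shifting the uint8 value into the float32 exponent field and bitcasting yields the scale 2**(s - 127), exactly as done with `(scale.to(tl.uint32) << 23)` in the kernel. A small plain-PyTorch check (illustrative only, not part of the vendored file):
import torch
scale_u8 = torch.tensor([126, 127, 128], dtype=torch.uint8)
decoded = (scale_u8.to(torch.int32) << 23).view(torch.float32)  # bitcast int32 -> float32
print(decoded)  # tensor([0.5, 1.0, 2.0]), i.e. 2 ** (scale - 127)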

View File

@ -0,0 +1,307 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py
# Triton is licensed under the MIT License.
# isort: off
# fmt: off
from dataclasses import dataclass
import triton
from triton_kernels.target_info import get_cdna_version
import torch
from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
@dataclass
class OptFlags:
block_m: int
block_n: int
block_k: int
num_warps: int
num_stages: int
group_m: int
xcd_swizzle: int
w_cache_modifier: str
split_k: int
is_persistent: bool
fused_scatter: bool
idle_sms: int
epilogue_subtile: int | None
arch: str
target_kernel_kwargs: dict
def __post_init__(self):
if self.fused_scatter and self.split_k != 1:
raise ValueError("Not supported")
def make_default_opt_flags_amd(
out_dtype,
lhs_dtype,
rhs_dtype,
precision_config,
m,
n,
k,
routing_data,
can_use_persistent_tma,
can_use_fused_scatter,
enforce_bitwise_invariance,
epilogue_effective_itemsize,
constraints,
):
constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
# tokens per expert
if routing_data is None:
tokens_per_expt = m
elif routing_data.expected_tokens_per_expt is None:
tokens_per_expt = max(1, m // routing_data.n_expts_tot)
else:
tokens_per_expt = routing_data.expected_tokens_per_expt
is_cdna4 = get_cdna_version() == 4
# block_m
if constraints.get("block_m", None):
block_m = constraints["block_m"]
elif enforce_bitwise_invariance:
block_m = 256 if is_cdna4 else 128
elif tokens_per_expt >= 512 and n >= 2048:
block_m = 256 if is_cdna4 else 128
elif is_cdna4 and m >= 512:
block_m = 128
else:
block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))
if routing_data is not None:
grid_m = routing_data.n_blocks(m, block_m)
else:
grid_m = triton.cdiv(m, block_m)
# group_m:
group_m = 4
# number of xcds
num_xcds = 8
xcd_swizzle = num_xcds
# block_nk:
block_n, block_k = opt_flags_amd.compute_block_nk(
n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
)
# Replace block_k if provided in constraints.
# TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
if constraints.get("block_k", None) is not None:
block_k = constraints["block_k"]
if constraints.get("block_n", None) is not None:
block_n = constraints["block_n"]
is_persistent = constraints.get("is_persistent", False)
# split_k:
if constraints.get("split_k", None) is not None:
split_k = constraints["split_k"]
elif is_persistent or enforce_bitwise_invariance:
split_k = 1
else:
grid_size = grid_m * ((n + block_n - 1) // block_n)
n_cu = torch.cuda.get_device_properties(0).multi_processor_count
split_k = max(1, n_cu // grid_size)
# w_cache_modifier:
w_cache_modifier = ".cg" if block_m <= 32 else None
# num_warps, num_stages
num_warps = 2 if (m is not None and m <= 16) else 8
num_stages = 2
# AMD-specific
target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
epilogue_subtile = constraints.get('epilogue_subtile', None)
if epilogue_subtile is None:
epilogue_subtile = 1
ret = OptFlags(
block_m=block_m,
block_n=block_n,
block_k=block_k,
num_warps=num_warps,
num_stages=num_stages,
group_m=group_m,
xcd_swizzle=xcd_swizzle,
w_cache_modifier=w_cache_modifier,
split_k=split_k,
is_persistent=is_persistent,
fused_scatter=constraints.get('fused_scatter', False),
idle_sms=0,
epilogue_subtile=epilogue_subtile,
arch=None,
target_kernel_kwargs=target_kernel_kwargs,
)
# check constraints
assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
return ret
def make_default_opt_flags_nvidia(
out_dtype,
lhs_dtype,
rhs_dtype,
precision_config,
m,
n,
k,
routing_data,
can_use_persistent_tma,
can_use_fused_scatter,
enforce_bitwise_invariance,
epilogue_effective_itemsize,
constraints,
):
constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
# tokens per expert
if routing_data is None:
tokens_per_expt = m
elif routing_data.expected_tokens_per_expt is None:
tokens_per_expt = max(1, m // routing_data.n_expts_tot)
else:
tokens_per_expt = routing_data.expected_tokens_per_expt
# pid swizzling
group_m = 8
xcd_swizzle = 1
# block_m
if constraints.get("block_m", None):
block_m = constraints["block_m"]
elif enforce_bitwise_invariance:
block_m = 128
else:
block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
# block n
arch = None
block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
# is_persistent
grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
n_sms = torch.cuda.get_device_properties(0).multi_processor_count
tiles_per_sm = grid_size / n_sms
supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
if constraints.get("is_persistent", None) is not None:
is_persistent = constraints["is_persistent"]
else:
has_simple_epilogue = precision_config.max_num_imprecise_acc is None
is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
# TEMP CHANGE
if precision_config.act_scale is not None or precision_config.out_scale is not None:
is_persistent = False
# block k
if constraints.get("block_k", None) is not None:
block_k = constraints["block_k"]
else:
block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config)
# split_k
if constraints.get("split_k", None) is not None:
split_k = constraints["split_k"]
elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
split_k = 1
else:
estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n)
split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
if split_k > 1:
# With split_k, results are written in f32. Use that for the following computations.
out_dtype = torch.float32
compute_num_stages_args = (
precision_config,
is_persistent,
block_m,
block_n,
block_k,
out_dtype,
lhs_dtype,
rhs_dtype,
)
if constraints.get("epilogue_subtile", None) is not None:
subtiles_to_check = [constraints["epilogue_subtile"]]
else:
subtiles_to_check = [1, 2, 4]
num_stages = -1
for ep in subtiles_to_check:
ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, ep, epilogue_effective_itemsize)
if ns > num_stages:
epilogue_subtile, num_stages = ep, ns
assert num_stages >= 1
if constraints.get("num_stages", None):
num_stages = constraints["num_stages"]
# fused scatter scratchpad
if constraints.get("fused_scatter", None) is not None:
fused_scatter = constraints["fused_scatter"]
else:
fused_scatter = can_use_fused_scatter and split_k == 1
# Handshake with the HBM swizzling
num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)
ret = OptFlags(
block_m=block_m,
block_n=block_n,
block_k=block_k,
num_warps=num_warps,
num_stages=num_stages,
fused_scatter=fused_scatter,
group_m=group_m,
xcd_swizzle=xcd_swizzle,
w_cache_modifier=None,
split_k=split_k,
is_persistent=is_persistent,
epilogue_subtile=epilogue_subtile,
arch=arch,
target_kernel_kwargs=dict(),
idle_sms=constraints.get("idle_sms", 0),
)
# check constraints
assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
return ret
# --------------
# User Interface
# --------------
_opt_flags_constraints: dict = dict()
_opt_flags: OptFlags | None = None
def update_opt_flags_constraints(constraints: dict[str, int]):
global _opt_flags_constraints
_opt_flags_constraints.update(constraints)
def reset_opt_flags_constraints():
global _opt_flags_constraints
_opt_flags_constraints = dict()
def set_opt_flags(opt_flags: OptFlags):
global _opt_flags
assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"
assert not _opt_flags, "opt_flags already set; please reset to None first"
_opt_flags = opt_flags
class InapplicableConstraint(Exception):
pass
def make_opt_flags(
out_dtype,
lhs_dtype,
rhs_dtype,
precision_config,
m,
n,
k,
routing_data,
can_use_persistent_tma,
can_use_fused_scatter,
epilogue_effective_itemsize,
):
if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter:
raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint")
enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
if _opt_flags is not None:
assert not _opt_flags_constraints
return _opt_flags
args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k,
routing_data, can_use_persistent_tma, can_use_fused_scatter,
enforce_bitwise_invariance, epilogue_effective_itemsize,
_opt_flags_constraints]
backend = triton.runtime.driver.active.get_current_target().backend
if backend == "hip":
return make_default_opt_flags_amd(*args)
if backend == "cuda":
return make_default_opt_flags_nvidia(*args)
assert False
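A minimal illustration of the constraint interface above (not part of the vendored file), assuming the package is importable as `triton_kernels` per the Source link: the constraints are module-level state consumed by `make_opt_flags`, so they should be reset once the tuned launches are done.
from triton_kernels.matmul_ogs_details.opt_flags import (
    update_opt_flags_constraints,
    reset_opt_flags_constraints,
)
# Pin tile sizes and disable split-K for the next matmul_ogs launches.
update_opt_flags_constraints({"block_m": 128, "block_k": 64, "split_k": 1})
try:
    pass  # ... run matmul_ogs here; make_opt_flags() will honor these constraints ...
finally:
    reset_opt_flags_constraints()  # avoid leaking constraints into unrelated launches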

View File

@ -0,0 +1 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.

View File

@ -0,0 +1,37 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
# Triton is licensed under the MIT License.
import torch
import triton
from triton_kernels.target_info import get_cdna_version
from triton_kernels.tensor import bitwidth
def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config):
lhs_width = bitwidth(lhs_dtype) / 8
rhs_width = bitwidth(rhs_dtype) / 8
# block_n:
n_cu = torch.cuda.get_device_properties(0).multi_processor_count
if n is not None:
if n <= 128 and (n & (n - 1)) == 0:
block_n = n
else:
block_n = max(32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu)))
elif block_m > 64:
block_n = 256
else:
block_n = 128
if get_cdna_version() == 4 and block_m == 128:
block_n = 512
# block_k needs to match the cacheline size (128B)
block_k = int(128 // min(lhs_width, rhs_width))
# TODO: block_k = 128 seems to work better for now.
# perhaps due to increased number of k loops to pipeline
if precision_config.weight_scale is not None and get_cdna_version() != 4:
block_k = 128
return block_n, block_k

View File

@ -0,0 +1,115 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
# Triton is licensed under the MIT License.
import torch
import triton
from triton_kernels import target_info
from triton_kernels.tensor import get_layout, bitwidth, FP4
from triton_kernels.tensor_details.layout import HopperMXScaleLayout
from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
def compute_grid_size(routing_data, m, n, block_m, block_n):
if routing_data is not None:
grid_m = routing_data.n_blocks(m, block_m)
else:
grid_m = triton.cdiv(m, block_m)
grid_n = (n + block_n - 1) // block_n
return grid_m * grid_n
def compute_block_n(n: int, arch, precision_config):
# block_n:
layout = get_layout(precision_config.weight_scale)
if isinstance(layout, HopperMXScaleLayout) and layout.num_warps == 4:
return 128
elif precision_config.max_num_imprecise_acc is None and n > 128:
return 256
else:
return max(16, min(128, triton.next_power_of_2(n)))
def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config):
lhs_width = bitwidth(lhs_dtype)
rhs_width = bitwidth(rhs_dtype)
# block_k needs to match the cacheline size (1024 bits)
block_k = int(1024 // min(lhs_width, rhs_width))
has_native_mxfp = target_info.cuda_capability_geq(10, 0)
if rhs_width == 4 and not has_native_mxfp:
block_k = 128
elif k is not None:
block_k = max(32, min(triton.next_power_of_2(k), block_k))
has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
if has_native_mxfp and is_persistent and has_mx_weight_scale:
block_k = min(block_k, 128)
return block_k
def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
device_props = torch.cuda.get_device_properties(0)
n_sms = device_props.multi_processor_count
split_k = n_sms // grid_size
if k is not None:
# avoid split_k for small k
num_block_k = triton.cdiv(k, block_k)
split_k = min(split_k, num_block_k // 4)
split_k = max(split_k, 1)
return split_k
def compute_num_warps(block_m, block_n, precision_config):
layout = get_layout(precision_config.weight_scale)
if isinstance(layout, HopperMXScaleLayout):
return layout.num_warps
return max(block_m * block_n // 4096, 4)
def compute_num_stages(
precision_config,
is_persistent,
block_m,
block_n,
block_k,
out_dtype,
lhs_dtype,
rhs_dtype,
epilogue_subtile,
epilogue_effective_itemsize,
):
if precision_config.max_num_imprecise_acc is not None:
return 3
weight_size = bitwidth(rhs_dtype) / 8
stage_size = block_m * block_k * lhs_dtype.itemsize + block_k * block_n * weight_size
device_props = torch.cuda.get_device_properties(0)
smem_capacity = device_props.shared_memory_per_block_optin
has_native_mxfp = target_info.cuda_capability_geq(10, 0)
if has_native_mxfp and getattr(precision_config, "weight_scale", None) is not None:
if rhs_dtype == FP4:
# 4-bit e2m1 weights are padded 2x
# https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
stage_size += block_k * block_n * weight_size
if is_persistent:
# Per-stage wait barrier
stage_size += 8
if target_info.cuda_capability_geq(10, 0):
acc_size = epilogue_effective_itemsize or out_dtype.itemsize
else:
acc_size = out_dtype.itemsize
if target_info.cuda_capability_geq(10, 0) and epilogue_subtile is not None:
acc_block_n = block_n // epilogue_subtile
else:
acc_block_n = block_n
# pipelined TMA store local to global, or
# pipelined layout conversion before store of the accumulator
# note: layout conversion has some padding
smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
if precision_config.weight_scale is not None:
# mx scales
stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
elif has_native_mxfp:
# mx scales
stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
num_stages = min(4, smem_capacity // int(stage_size))
return num_stages
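As a concrete reading of `compute_block_k` above (illustrative arithmetic only, ignoring the MXFP-specific branches): the cacheline target is 1024 bits, so the K block is sized by the narrower operand and then clamped by the actual k.
import triton
lhs_bits, rhs_bits = 16, 8                 # e.g. bf16 activations, fp8 weights
block_k = 1024 // min(lhs_bits, rhs_bits)  # -> 128
k = 7168
block_k = max(32, min(triton.next_power_of_2(k), block_k))
print(block_k)  # 128; a tiny k (e.g. 48) would instead clamp this to 64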

View File

@ -0,0 +1,46 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/numerics.py
# Triton is licensed under the MIT License.
import torch
from dataclasses import dataclass
MAX_FINITE_FLOAT8E5 = 57344.0
MAX_FINITE_FLOAT8E4NV = 448.0
MAX_FINITE_FLOAT8E4B8 = 240.0
@dataclass(frozen=True)
class BaseFlexData:
dtype: torch.dtype | None = None
def view(self, x: torch.Tensor):
if self.dtype is None:
return x
return x.view(self.dtype)
def reinterpret(self, x):
if self.dtype is None or x.dtype.itemsize > 1:
return x
return x.view(self.dtype)
@dataclass(frozen=True)
class InFlexData(BaseFlexData):
scale: torch.Tensor | None = None
@property
def is_per_batch(self):
return False if self.scale is None else len(self.scale) > 1
@dataclass(frozen=True)
class OutFlexData(BaseFlexData):
expected_scale: torch.Tensor | None = None
actual_scale: torch.Tensor | None = None
checksum_scale: torch.Tensor | None = None
def __iter__(self):
yield self.expected_scale
yield self.actual_scale
yield self.checksum_scale
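For orientation (illustrative only, not part of the vendored file): `OutFlexData` iterates as the (expected, actual, checksum) scale triple that the wrapper code unpacks and passes to kernels such as `_reduce_grouped` above, and `InFlexData.is_per_batch` is False whenever no scale tensor is attached.
import torch
from triton_kernels.numerics import InFlexData, OutFlexData
out_flex = OutFlexData(dtype=torch.float8_e4m3fn,
                       expected_scale=torch.ones(1),
                       actual_scale=torch.zeros(1),
                       checksum_scale=None)
expected, actual, checksum = out_flex      # unpacks via __iter__ in field order
assert InFlexData().is_per_batch is False  # no scale tensor -> per-tensor scaling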

View File

@ -0,0 +1,3 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/numerics_details/__init__.py
# Triton is licensed under the MIT License.

View File

@ -0,0 +1,194 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/numerics_details/flexpoint.py
# Triton is licensed under the MIT License.
from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
import triton
import triton.language as tl
from triton_kernels.target_info import cuda_capability_geq
# -------------------------------
# Kernels stuff
# -------------------------------
TL_MAX_FINITE_FLOAT8E5 = tl.constexpr(MAX_FINITE_FLOAT8E5)
TL_MAX_FINITE_FLOAT8E4NV = tl.constexpr(MAX_FINITE_FLOAT8E4NV)
TL_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(MAX_FINITE_FLOAT8E4B8)
TL_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(1.750)
TL_MAX_FINITE_FLOAT16 = tl.constexpr(65472.0)
TL_RCP_MAX_FINITE_FLOAT8E5 = tl.constexpr(0x37924925) # 0x1.24924Ap-16
TL_RCP_MAX_FINITE_FLOAT8E4NV = tl.constexpr(0x3B124925) # 0x1.24924Ap-9
TL_RCP_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(0x3B888889) # 0x1.111112p-8
TL_RCP_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(0x3F124925) # 0x1.24924Ap-1
TL_RCP_MAX_FINITE_FLOAT16 = tl.constexpr(0x37802008) # 0x1.004010p-16
@triton.jit
def max_finite(dtype):
if dtype == tl.constexpr(tl.float8e5):
return TL_MAX_FINITE_FLOAT8E5
elif dtype == tl.constexpr(tl.float8e4nv):
return TL_MAX_FINITE_FLOAT8E4NV
elif dtype == tl.constexpr(tl.float8e4b8):
return TL_MAX_FINITE_FLOAT8E4B8
elif dtype == tl.constexpr(tl.float8e4b15):
return TL_MAX_FINITE_FLOAT8E4B15
elif dtype == tl.constexpr(tl.float16):
return TL_MAX_FINITE_FLOAT16
else:
tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
@triton.jit
def rcp_max_finite(dtype):
if dtype == tl.constexpr(tl.float8e5):
return TL_RCP_MAX_FINITE_FLOAT8E5
elif dtype == tl.constexpr(tl.float8e4nv):
return TL_RCP_MAX_FINITE_FLOAT8E4NV
elif dtype == tl.constexpr(tl.float8e4b8):
return TL_RCP_MAX_FINITE_FLOAT8E4B8
elif dtype == tl.constexpr(tl.float8e4b15):
return TL_RCP_MAX_FINITE_FLOAT8E4B15
elif dtype == tl.constexpr(tl.float16):
return TL_RCP_MAX_FINITE_FLOAT16
else:
tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
@triton.jit
def sm86_min_nan_xorsign_abs_f32(a, b):
"""Wrapper for min.NaN.xorsign.abs.f32 PTX instruction.
Computes the minimum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
NaN inputs are propagated to the output.
Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
"""
tl.static_assert(cuda_capability_geq(8, 6), "min.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+")
tl.static_assert(a.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs")
tl.static_assert(b.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs")
return tl.inline_asm_elementwise(
"""{
min.NaN.xorsign.abs.f32 $0, $1, $2;
}""",
"=r,r,r",
[a, b],
dtype=tl.float32,
is_pure=True,
pack=1,
)
@triton.jit
def sm86_max_nan_xorsign_abs_f32(a, b):
"""Wrapper for max.NaN.xorsign.abs.f32 PTX instruction.
Computes the maximum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
NaN inputs are propagated to the output.
Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
"""
tl.static_assert(cuda_capability_geq(8, 6), "max.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+")
tl.static_assert(a.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs")
tl.static_assert(b.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs")
return tl.inline_asm_elementwise(
"""{
max.NaN.xorsign.abs.f32 $0, $1, $2;
}""",
"=r,r,r",
[a, b],
dtype=tl.float32,
is_pure=True,
pack=1,
)
@triton.jit
def load_scale(scale_ptr):
return 1.0 if scale_ptr is None else tl.load(scale_ptr)
@triton.jit
def flex_to_float(x, scale_ptr):
scale = load_scale(scale_ptr)
return x.to(tl.float32) * scale
@triton.jit
def clip(x, limit):
res = tl.minimum(x, limit)
res = tl.maximum(-limit, res)
return res
@triton.jit
def nan_propagating_absmax_reduce(x, axis=None):
if cuda_capability_geq(8, 6):
# abs-max-reduce as floating-point if `max.NaN.xorsign.abs.f32` is supported.
x_absmax = tl.reduce(x, axis, sm86_max_nan_xorsign_abs_f32)
# Note: sign of reduction result is the xor of signs of all inputs, explicitly clear the sign bit to fix it.
x_absmax = x_absmax.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
else:
# Clear the sign bit, max-reduce as integer (same as NaN-propagating max-reduce as float)
masked_abs_x = x.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
x_absmax = tl.max(masked_abs_x, axis)
return x_absmax
@triton.jit
def compute_scale(x, Out):
x_absmax = nan_propagating_absmax_reduce(tl.ravel(x, can_reorder=True))
# atomic_max does not propagate NaNs, so we replace them with +inf (0x7f800000).
# We use integer minimum because NaNs are above +inf in integer representation.
x_absmax = tl.minimum(x_absmax, 0x7F800000).to(tl.float32, bitcast=True)
RCP_MAX_VALUE = rcp_max_finite(Out.dtype.element_ty)
return tl.fma(x_absmax, RCP_MAX_VALUE.to(tl.float32, bitcast=True), 1.0e-30)
@triton.jit
def update_scale(x, scale_ptr, Out) -> None:
if scale_ptr is not None:
scale = compute_scale(x, Out)
tl.atomic_max(scale_ptr, scale, sem="relaxed")
@triton.jit
def float_to_flex(
x,
expected_scale_ptr_or_val,
actual_scale_ptr,
checksum_scale_ptr,
mask,
Out,
saturate_infs: tl.constexpr,
):
if expected_scale_ptr_or_val is not None:
if expected_scale_ptr_or_val.dtype.is_ptr():
invscale = 1.0 / tl.load(expected_scale_ptr_or_val)
else:
invscale = 1.0 / expected_scale_ptr_or_val
else:
invscale = 1.0
if checksum_scale_ptr is not None:
x_int32 = x.to(tl.int32, bitcast=True)
zero = tl.cast(0.0, tl.int32)
if mask is not None:
x_int32 = tl.where(mask, x_int32, zero)
checksum_local = tl.xor_sum(tl.ravel(x_int32, can_reorder=True), 0)
tl.atomic_add(checksum_scale_ptr, checksum_local)
if mask is not None:
if actual_scale_ptr is not None:
x = tl.where(mask, x, 0.0)
update_scale(x, actual_scale_ptr, Out)
x = x * invscale
# if expected_scale_ptr is not None, we applied flexpoint scale. We only want to clip in this case.
if expected_scale_ptr_or_val is not None:
if saturate_infs:
CLIP_VALUE = max_finite(Out.dtype.element_ty)
x = clip(x, CLIP_VALUE)
return x
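The helpers above implement flexpoint scaling: outputs are divided by an expected scale before the fp8 store, while the true absmax is atomically accumulated so the scale can be corrected later. A rough PyTorch analogue of that round trip (illustrative only, not the vendored API; requires a PyTorch build with float8 dtypes):
import torch
x = torch.randn(4096) * 900.0                    # exceeds fp8e4m3's max finite value (448)
scale = x.abs().max() / 448.0 + 1e-30            # mirrors compute_scale's rcp-of-max-finite plus epsilon
x_fp8 = (x / scale).to(torch.float8_e4m3fn)      # float_to_flex: divide by scale, store as fp8
x_back = x_fp8.to(torch.float32) * scale         # flex_to_float: load fp8, multiply scale back in
print((x - x_back).abs().max() / x.abs().max())  # small relative error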

View File

@ -0,0 +1,307 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/numerics_details/mxfp.py
# Triton is licensed under the MIT License.
# isort: off
# fmt: off
from enum import Enum
import triton
import torch
import torch.nn.functional as F
from .mxfp_details._upcast_from_mxfp import _upcast_from_mxfp
from .mxfp_details._downcast_to_mxfp import _downcast_to_mxfp, MXFP_BLOCK_SIZE, _quantize_mxfp8_fn
# -----------------------------------------------------------------------------
# Dequantization / Quantization Utilities
# -----------------------------------------------------------------------------
class DequantScaleRoundingMode(Enum):
ROUND_UP = 0
ROUND_DOWN = 1
def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
"""
Convert the src weights to mx format. The src weight is quantized along the axis dimension.
    If out_quant_type is torch.uint8, we output mxfp4, where two e2m1 values are packed into a single byte.
    Note that this means the k_dim of the tensor will be half of the logical k_dim.
    If out_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8 values
    stored in their respective formats.
"""
ndim = src_tensor.ndim
assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
axis = axis if axis >= 0 else axis + ndim
# downcast
src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1)
is_fp4 = out_quant_type == torch.uint8
is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
assert is_fp4 or is_fp8
divisor = 2 if is_fp4 else 1
L = src_tensor.shape[-1]
if is_fp4:
assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
out_shape = src_tensor.shape[:-1] + (L // divisor, )
out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )
out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)
if src_tensor.numel() > 0:
kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
kernel_scale = out_scale.view(-1, out_scale.shape[-1])
BLOCK_OUT_DIM = 128
BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)
_downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
*kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
*kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)
out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)
return out_quant_tensor, out_scale
def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
"""
Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16.
The function assumes that the tensors were quantized along the given axis.
It permutes the tensor so that the quantized axis is last, reshapes to 2D,
launches the Triton upcast kernel, and then unpermutes back to the original order.
"""
ndim = tensor.ndim
assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
axis = axis if axis >= 0 else axis + ndim
assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
f"Got {tensor.ndim=} and {scale.ndim=}")
# dtype checks
assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
f"Invalid tensor dtype {tensor.dtype=}"
assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
assert target_dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {target_dtype=}"
# upcast
logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
scale = scale.transpose(axis, scale.ndim - 1).contiguous()
out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=target_dtype, device=tensor.device)
reshaped_out = out.view(-1, out.shape[-1])
reshaped_tensor = tensor.view(-1, tensor.shape[-1])
reshaped_scale = scale.view(-1, scale.shape[-1])
BLOCK_OUT_DIM = 128
BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
_upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
*reshaped_scale.stride(), reshaped_tensor,
*reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM,
BLOCK_QUANT_DIM, num_warps=8)
out = out.transpose(axis, scale.ndim - 1).contiguous()
return out
# ------------
def right_shift_unsigned(x, shift):
# CUDA torch does not support bit ops on uint32, so we need to mask to get unsigned right shift
return (x >> shift) & ((1 << (32 - shift)) - 1)
def get_max_quant_val(dtype: torch.dtype):
d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0}
assert dtype in d
return d[dtype]
def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
"""
Converts the src tensor to the output format specified by out_quant_type.
axis: The axis along which the tensors are contiguous and quantization is applied.
DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN.
Returns:
out_quant_tensor: Quantized tensor in mx format.
For mxfp8, the output has the same shape as src_tensor.
For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
where L is the original length along that axis.
"""
# This should probably be packed into its own tiny class
ndim = src_tensor.ndim
assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
assert src_tensor.dtype in {torch.float32, torch.bfloat16,
torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}"
axis = axis if axis >= 0 else axis + ndim
is_fp4 = out_quant_type == torch.uint8
is_fp8 = "float8" in str(out_quant_type)
assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"
device = src_tensor.device
# For mxfp4 conversion, we assume the contiguous axis length is even.
if is_fp4:
axis_shape = src_tensor.size(axis)
assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even."
# Permute the tensor so that the contiguous axis becomes the last dimension.
src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
axis_shape = src.shape[-1]
# Pad the axis to be divisible by 32, in case it is not.
next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
pad_amount = next_multiple - axis_shape
padded_src = F.pad(src, (0, pad_amount))
valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
padded_axis_shape = padded_src.size(-1) # now divisible by 32
# --- Compute per-group maximums for scale ---
    # Set padded entries to -1 so they don't affect the max.
abs_f = torch.abs(padded_src)
abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
# Reshape the last dimension into groups of 32.
new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
abs_groups = abs_f.view(*new_shape)
# Compute maximum along the group dimension (of size 32).
max_val, _ = abs_groups.max(dim=-1, keepdim=True)
# Choose a max quantization value depending on type.
max_quant_val = get_max_quant_val(out_quant_type)
dequant_scale = max_val / max_quant_val # shape: (..., padded_axis_shape//32, 1)
# Convert to int to round the FP32 scale, prior to quantization!
ds_int = dequant_scale.view(torch.int32)
if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000
else:
ds_int_rounded = ds_int & 0x7F800000
# Reinterpret back as float32.
dequant_scale_rounded = ds_int_rounded.view(torch.float32)
# Compute the quantization scale.
quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded)
# Quantize the tensor
orig_padded_shape = padded_src.shape
padded_src_groups = padded_src.view(*new_shape)
quant_tensor = padded_src_groups * quant_scale
# Reshape back to the original shape and trim padding
quant_tensor = quant_tensor.view(orig_padded_shape)
quant_tensor = quant_tensor[..., :axis_shape]
# Finally, convert the quantized tensor to the target format
if is_fp8:
# Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val)
out_weight = quant_tensor.to(out_quant_type)
else:
assert is_fp4, f"Invalid output quantization type {out_quant_type}"
# For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
# First, reinterpret the quantized tensor bits.
q_int = quant_tensor.contiguous().view(torch.int32)
# Extract sign, exponent, and mantissa.
signs = q_int & 0x80000000
exponents = right_shift_unsigned(q_int, 23) & 0xFF
mantissas = q_int & 0x7FFFFF
E8_BIAS = 127
E2_BIAS = 1
# Adjust mantissas for subnormals.
mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
(E8_BIAS - exponents - 1), mantissas)
exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1)
e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))
e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8) # shape: (..., even_axis_shape)
# Pack pairs of 4-bit values along the last dimension.
e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2)
evens = e2m1_value[..., 0]
odds = e2m1_value[..., 1]
out_weight = evens | (odds << 4) # shape: (..., axis_shape//2)
# --- Process and output the scale ---
dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8) # shape: (..., axis_shape//32, 1)
dq_scale = dq_scale.squeeze(-1)
out_weight = out_weight.transpose(axis, src_tensor.ndim - 1)
dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1)
return out_weight, dq_scale
def cvt_e2m1_to_fp32(input_tensor):
assert input_tensor.dtype == torch.uint8
input_tensor = input_tensor.to(torch.int32)
evens = input_tensor & 0xF
odds = (input_tensor >> 4) & 0xF
vals = [0.0, 0.5, 1, 1.5, 2, 3, 4, 6]
outputs = torch.tensor(vals, dtype=torch.float32, device=input_tensor.device)
outputs = torch.cat([outputs, -outputs])
even_floats = outputs[evens]
odd_floats = outputs[odds]
output_tensor = torch.stack([even_floats, odd_floats], dim=-1)
output_tensor = output_tensor.view(*input_tensor.shape[:-1], -1)
return output_tensor
def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
"""
Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
axis: The axis along which dequantization is applied.
Returns:
out_weight: Tensor in the target format.
"""
ndim = tensor.ndim
assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2
assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}"
# Permute the tensor and scale so that the quantization axis becomes the last dimension
axis = axis if axis >= 0 else axis + ndim
scale = scale.transpose(axis, scale.ndim - 1)
tensor = tensor.transpose(axis, tensor.ndim - 1)
dq_scale = (scale.to(torch.int32) << 23).view(torch.float32) # Shift to the exponent and bitcast to fp32
if tensor.dtype == torch.uint8:
fp32_tensor = cvt_e2m1_to_fp32(tensor)
else:
fp32_tensor = tensor.to(torch.float32)
logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1)
axis_shape = fp32_tensor.size(-1)
padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
pad_size = padded_axis_shape - axis_shape
padded_tensor = F.pad(fp32_tensor, (0, pad_size))
new_axis_shape = padded_tensor.shape[-1]
new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
padded_tensor = padded_tensor.view(*new_shape)
dq_scale_padded = dq_scale.unsqueeze(-1) # shape: [..., ceil(axis_shape/32), 1]
out_padded = padded_tensor * dq_scale_padded
# Flatten back and remove the padded tail
out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
out_tensor = out_padded[..., :axis_shape]
out_tensor = out_tensor.to(target_dtype).contiguous()
out_tensor = out_tensor.transpose(axis, tensor.ndim - 1)
return out_tensor
quantize_mxfp8_fn = _quantize_mxfp8_fn
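A minimal round trip through the pure-PyTorch reference paths defined above (illustrative only; assumes the vendored package is importable as `triton_kernels` per the Source link). mxfp4 packs two e2m1 values per byte, so the quantized axis halves and the scale axis shrinks by a factor of 32.
import torch
from triton_kernels.numerics_details.mxfp import (
    downcast_to_mxfp_torch, upcast_from_mxfp_torch,
)
w = torch.randn(256, 512, dtype=torch.bfloat16)
q, s = downcast_to_mxfp_torch(w, torch.uint8, axis=-1)        # mxfp4 weights + e8m0 scales
print(q.shape, s.shape)                                        # (256, 256) and (256, 16)
w_hat = upcast_from_mxfp_torch(q, s, torch.bfloat16, axis=-1)
print((w - w_hat).abs().mean())                                # coarse 4-bit reconstruction error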

View File

@ -0,0 +1 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.

View File

@ -0,0 +1,162 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
# Triton is licensed under the MIT License.
import triton
import triton.language as tl
# fmt: off
MXFP_BLOCK_SIZE = tl.constexpr(32)
@triton.jit
def _get_max_quant_val(dtype: tl.constexpr):
if dtype == tl.uint8:
return 6.0
elif dtype == tl.float8e5:
return 57344.0
elif dtype == tl.float8e4nv:
return 448.0
else:
tl.static_assert(False, f"Invalid {dtype=}")
@triton.jit
def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
# Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
f32_tensor = src_tensor.to(tl.float32)
abs_tensor = tl.abs(f32_tensor)
abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation
abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
if DEQUANT_SCALE_ROUNDING_MODE == 0:
# DequantScaleRoundingMode.ROUND_UP
# compute 2 ** ceil(log2(dequant_scale))
        # Adding 0x007FFFFF increments the exponent by 1 unless the mantissa is all zeros.
        # Corner case: an exponent of 0xFF would overflow, but that value is already
        # NaN, so assume we don't care.
dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
else:
# DequantScaleRoundingMode.ROUND_DOWN
# compute 2 ** floor(log2(dequant_scale))
assert DEQUANT_SCALE_ROUNDING_MODE == 1
dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)
f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
quant_tensor = f32_tensor * quant_scale
# Reshape the tensors after scaling
quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
# Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])
# First, we simply extract the exponent part of the scales and store the result
dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
# Now we must convert the tensors to the mx format.
if is_fp8:
out_tensor = quant_tensor.to(mx_tensor_dtype)
else:
quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
signs = quant_tensor & 0x80000000
exponents = (quant_tensor >> 23) & 0xFF
mantissas = (quant_tensor & 0x7FFFFF)
# 0.25 <= x < 0.75 maps to 0.5, a denormal number
E8_BIAS = 127
E2_BIAS = 1
# Move implicit bit 1 at the beginning to mantissa for denormals
adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)
# For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
# Combine sign, exponent, and mantissa, while saturating
# rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)
e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
evens, odds = tl.split(e2m1_value)
out_tensor = evens | (odds << 4)
return out_tensor, dequant_scale_exponent
@triton.jit
def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
src_ptr, stride_src_outer, stride_src_quant,
outer_dim, quant_dim,
BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")
# uint8 signifies two fp4 e2m1 values packed into a single byte
mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")
src_dtype: tl.constexpr = src_ptr.dtype.element_ty
tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16) or (src_dtype == tl.float32), f"{src_dtype=} must be bfloat16 or float16 or float32")
is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
outer_block = tl.program_id(0).to(tl.int64)
quant_block = tl.program_id(1).to(tl.int64)
K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
start_out = outer_block * BLOCK_SIZE_OUT_DIM
src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer
offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
mask_src_quant = start_src_quant + offs_src_quant < quant_dim
mask_n = start_out + offs_outer < outer_dim
full_mask_src = mask_src_quant & mask_n
mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
full_mask_mxt = mask_mxt_quant & mask_n
scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
full_scale_mask = scale_mask_k & mask_n
src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)
out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
DEQUANT_SCALE_ROUNDING_MODE)
tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
@triton.jit(repr=lambda _: "_dequantize_mxfp8")
def _quantize_mxfp8_fn(input, mask, pid=None):
return _compute_quant_and_scale(input, mask, tl.float8e4nv)
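The ROUND_UP branch above rounds the dequant scale up to the next power of two purely with integer bit tricks: add 0x007FFFFF (which carries into the exponent unless the mantissa is zero), then mask to the exponent field. A standalone float32 check (illustrative only, not part of the vendored file):
import struct
def round_up_pow2(x: float) -> float:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x007FFFFF) & 0x7F800000   # carry into exponent, drop mantissa
    return struct.unpack("<f", struct.pack("<I", bits))[0]
print(round_up_pow2(0.3))  # 0.5
print(round_up_pow2(1.0))  # 1.0 (already a power of two, mantissa is zero)
print(round_up_pow2(3.0))  # 4.0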

View File

@ -0,0 +1,129 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
# Triton is licensed under the MIT License.
import triton
import triton.language as tl
from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
# fmt: off
@triton.jit
def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
# uint8 signifies two fp4 e2m1 values packed into a single byte
mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16 or dst_dtype == tl.float32)
tl.static_assert(
mx_tensor_dtype == tl.uint8
or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
"mx_tensor_ptr must be uint8 or float8 or dst_dtype")
tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
# Determine if we are dealing with fp8 types.
is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
# Compute starting indices for the quantized (packed) dimension and the outer dimension.
outer_block = tl.program_id(0).to(tl.int64)
quant_block = tl.program_id(1).to(tl.int64)
start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
start_out = outer_block * BLOCK_SIZE_OUT_DIM
mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant
# Compute offsets and masks.
offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
mask_outer = start_out + offs_outer < outer_dim
mask_out_quant = start_out_quant + offs_out_quant < quant_dim
full_mask_out = mask_out_quant & mask_outer
mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
full_mask_src = mask_src_quant & mask_outer
mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
full_scale_mask = mask_scale & mask_outer
tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer
# Load the packed tensor and scale.
tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)
# Upcast the scale to the destination type.
if dst_dtype == tl.bfloat16:
dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
else:
dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
if dst_dtype == tl.float16:
dst_scale = dst_scale.to(tl.float16)
# Now upcast the tensor.
intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype
if is_fp8:
dst_tensor = tensor.to(intermediate_dtype)
if tensor.dtype == tl.float8e5:
from_e_bits: tl.constexpr = 5
from_m_bits: tl.constexpr = 2
to_e_bits: tl.constexpr = 8 if intermediate_dtype == tl.bfloat16 else 5
to_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
# Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
dst_tensor = tl.where(
(tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
(dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(intermediate_dtype, bitcast=True),
dst_tensor,
)
else:
assert is_fp4
dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15
dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800
dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10
# e2m1
em0 = tensor & 0x07
em1 = tensor & 0x70
x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
# Three cases:
# 1) x is normal and non-zero: Correct bias
x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
# 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
# 3) x is zero, do nothing
dst_tensor = tl.interleave(x0, x1).to(intermediate_dtype, bitcast=True)
dst_tensor = dst_tensor.to(dst_dtype)
# Reshape for proper broadcasting: the scale was stored with a 32-element "inner" grouping.
dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
scale = scale.reshape(dst_scale.shape)
out_tensor = dst_tensor * dst_scale
# Correct any NaNs encoded via the scale.
out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)

View File

@ -0,0 +1,21 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/proton_opts.py
# Triton is licensed under the MIT License.
# proton options
import os
_launch_metadata_allow_sync = None
def launch_metadata_allow_sync():
global _launch_metadata_allow_sync
if _launch_metadata_allow_sync is None:
_launch_metadata_allow_sync = not (os.getenv("PROTON_LAUNCH_METADATA_NOSYNC") == "1")
return _launch_metadata_allow_sync
def set_launch_metadata_allow_sync(allow_sync: bool):
global _launch_metadata_allow_sync
_launch_metadata_allow_sync = allow_sync
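# Illustrative sketch, not from the upstream triton_kernels source: the programmatic
# override takes precedence over the PROTON_LAUNCH_METADATA_NOSYNC environment default.
def _proton_opts_example():
    set_launch_metadata_allow_sync(False)   # force the non-synchronizing path
    return launch_metadata_allow_sync()     # -> False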

View File

@ -0,0 +1 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.

View File

@ -0,0 +1,115 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/reduction_details/reduce_bitmatrix.py
# Triton is licensed under the MIT License.
import torch
import triton
import triton.language as tl
@triton.jit
def vpopc(x):
"""
Vertical popcount
Input x : uint32[..., N]
Output y : uint32[..., 32]
semantics : y[..., i] = sum_j((x[..., j] >> i) & 1)
credits: @apgoucher
"""
tl.static_assert(x.dtype == tl.uint32, "x should consist of 32-bit unsigned integers")
BLOCK_N: tl.constexpr = x.shape[-1] # summation axis
BATCHES: tl.constexpr = x.numel // BLOCK_N # number of batches
if BLOCK_N >= 8:
sa1: tl.constexpr = 8
else:
sa1: tl.constexpr = BLOCK_N
# create 8-way sums in 4-bit fields:
y = tl.reshape(x, [BATCHES, BLOCK_N // sa1, sa1, 1])
y = (y >> tl.arange(0, 4)[None, None, None, :]) & 0x11111111
y = tl.sum(y, 2) # [BATCHES, BLOCK_N // sa1, 4]
if BLOCK_N >= 128:
sa2: tl.constexpr = 16
else:
sa2: tl.constexpr = BLOCK_N // sa1
# create 128-way sums in 8-bit fields:
y = tl.reshape(y, [BATCHES, BLOCK_N // (sa1 * sa2), sa2, 1, 4])
y = (y >> (4 * tl.arange(0, 2))[None, None, None, :, None]) & 0x0f0f0f0f
y = tl.sum(y, 2) # [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4]
sa3: tl.constexpr = BLOCK_N // (sa1 * sa2)
# create N-way sums in 32-bit fields:
y = tl.reshape(y, [BATCHES, 1, sa3, 8])
y = (y >> (8 * tl.arange(0, 4))[None, :, None, None]) & 0x000000ff
y = tl.sum(y, 2) # [BATCHES, 4, 8]
y = tl.reshape(y, x.shape[:-1] + [32])
return y
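# Illustrative sketch, not from the upstream triton_kernels source: the same
# vertical-popcount semantics expressed with plain torch ops on the host, assuming
# the input holds 32-bit patterns in an integer tensor. The helper name is hypothetical.
def _vpopc_reference(x: torch.Tensor) -> torch.Tensor:
    xi = x.to(torch.int64) & 0xFFFFFFFF
    bits = torch.arange(32, device=x.device, dtype=torch.int64)
    # y[..., i] = sum_j((x[..., j] >> i) & 1)
    return ((xi.unsqueeze(-1) >> bits) & 1).sum(dim=-2)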
@triton.jit
def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
tl.store(Ret + offs, 0)
@triton.jit
def _sum_bitmatrix_rows(B, shape_bm, stride_bm: tl.constexpr, stride_bn: tl.constexpr, # input bitmatrix
Ret, Partials, stride_pm: tl.constexpr, stride_pn, shape_pn, # outputs
BLOCK_MM: tl.constexpr, BLOCK_M: tl.constexpr):
tl.static_assert(BLOCK_MM % BLOCK_M == 0)
TILE_SIZE: tl.constexpr = BLOCK_MM // BLOCK_M
if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr():
shape_bm = tl.load(shape_bm)
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM)
offs_n = pid_n * 32 + tl.arange(0, 32)
n_rows = shape_bm
bits = tl.load(B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0)
bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M])
ret = vpopc(bits) # [TILE_SIZE, 32]
offs_t = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
tl.atomic_add(Ret + offs_n, tl.sum(ret, 0), sem="relaxed")
tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret)
def clear_sums(n_cols, device, MEMSET_BLOCK=512):
cdiv = triton.cdiv
blocks = cdiv(n_cols, MEMSET_BLOCK)
out_ret = torch.empty((blocks * MEMSET_BLOCK, ), device=device, dtype=torch.int32)
_sum_bitmatrix_memset[(blocks, )](out_ret, MEMSET_BLOCK)
return out_ret
def sum_bitmatrix_rows(x, out_ret, partials_block_size=None):
assert partials_block_size is not None
cdiv = triton.cdiv
PARTIALS_BLOCK_M = partials_block_size
n_rows, n_cols = x.shape
n_rows_max = x.shape_max[0]
assert out_ret.shape == (n_cols, )
TILE_SIZE = max(1, 128 // PARTIALS_BLOCK_M)
BLOCK_MM = PARTIALS_BLOCK_M * TILE_SIZE
pids_x = cdiv(n_rows_max, BLOCK_MM)
pids_y = cdiv(n_cols, 32)
out_partials = torch.empty((pids_y * 32, pids_x * TILE_SIZE), device=out_ret.device, dtype=torch.int32)
out_partials = torch.transpose(out_partials, 0, 1)
# output tensors
_sum_bitmatrix_rows[(pids_x, pids_y)](
x.storage.data, n_rows, x.stride(0), x.stride(1), # input
out_ret, # output [final reduction]
out_partials, out_partials.stride(0), out_partials.stride(1),
out_partials.shape[1], # output [partial reductions]
BLOCK_M=PARTIALS_BLOCK_M, BLOCK_MM=BLOCK_MM, # constants
num_warps=8)
out_partials = out_partials[:cdiv(n_rows_max, PARTIALS_BLOCK_M), :]
return out_ret, out_partials

396
triton_kernels/routing.py Normal file
View File

@ -0,0 +1,396 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/routing.py
# Triton is licensed under the MIT License.
import torch
import triton
from dataclasses import dataclass, field
from .routing_details._routing_compute import _combined_routing_compute
from .routing_details._routing_compute import _combined_routing_memset
from .routing_details._routing_compute import _routing_clear_bitmatrix
from .routing_details._expt_data import _expt_data_memset
from .routing_details._expt_data import _expt_data_compute
from .target_info import is_hip
@dataclass
class GatherIndx:
"""
Indices for an operation that performs:
Y = X[src_idx, :]
"""
# array such that `dst_idx[src_idx] = arange(0, N)`
src_indx: torch.Tensor
dst_indx: torch.Tensor
@dataclass
class ScatterIndx:
"""
Indices for an operation that performs:
Y[dst_idx, :] = X
"""
# array such that `dst_idx[src_idx] = arange(0, N)`
src_indx: torch.Tensor
dst_indx: torch.Tensor
@dataclass
class ExptData:
# hist[i] is the number of tokens routed to expert i
hist: torch.Tensor
# token_offs_raw[i] is the offset of the first token routed
# to expert i in an expert-sorted array
token_offs_raw: torch.Tensor
# token_offs_pad[block][i] is the offset of the first token routed
# to expert i in an expert-sorted array, assuming histogram
# rounded to the next multiple of `block`
token_offs_pad: dict[int, torch.Tensor]
# block_pid_map[block] contains one value for each `pid` launched by
# the matrix multiplication kernel launched with BLOCK_M=block:
# - the value is -1 if the `pid` has no work to do
# - otherwise, the value is two int16 (packed as an int32) that
# correspond respectively to (1) the expert assigned to
# the tokens processed by this pid; (2) the block assigned to the
# tokens processed by this pid (think `pid_m` in a regular matmul)
# see `test_routing.py` for a reference implementation and more details
block_pid_map: dict[int, torch.Tensor]
def __post_init__(self):
if self.hist is not None:
assert self.hist.dtype == torch.int32
if self.token_offs_raw is not None:
assert self.token_offs_raw.dtype == torch.int32
if self.token_offs_pad is not None:
for v in self.token_offs_pad.values():
assert v.dtype == torch.int32
if self.block_pid_map is not None:
for v in self.block_pid_map.values():
assert v.dtype == torch.int32
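# Illustrative sketch, not from the upstream triton_kernels source: how a histogram
# maps to the offset arrays described above, for a hypothetical hist and block = 16.
# hist = [3, 1, 0, 5]  ->  token_offs_raw     = [0, 3, 4, 4, 9]   (plain cumsum)
#                          token_offs_pad[16] = [0, 1, 2, 2, 3]   (cumsum of ceil(hist / 16))
def _expt_data_offsets_example():
    hist = torch.tensor([3, 1, 0, 5], dtype=torch.int32)
    zero = torch.zeros(1, dtype=torch.int32)
    token_offs_raw = torch.cat([zero, torch.cumsum(hist, 0).int()])
    token_offs_pad_16 = torch.cat([zero, torch.cumsum((hist + 15) // 16, 0).int()])
    return token_offs_raw, token_offs_pad_16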
@dataclass
class RoutingData:
gate_scal: torch.Tensor = field()
expt_hist: torch.Tensor = field()
n_expts_tot: int = field()
n_expts_act: int = field()
expt_data: ExptData = None
# Used to make perf annotation cleaner: when we use expert sharding, we can
# use this to tell the "expected" number of local tokens per expert, because
# the actual number can vary per each input.
expected_tokens_per_expt: int = field(default=None)
def n_blocks(self, n_rows, block_m):
if n_rows <= self.n_expts_tot:
return n_rows
else:
return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1
# --------------------------
# sort tokens by expert
# --------------------------
class SortTokens(torch.autograd.Function):
@staticmethod
def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
HIST_BLOCK_M = 32
INDX_OFFS_BLOCK_M = 512
MEMSET_BLOCK = 1024
cdiv = triton.cdiv
device = expt_scal.device
dtype = expt_scal.dtype
n_tokens_raw, _ = bitmatrix.shape
n_tokens_pad, n_expts_act = expt_scal.shape
n_gates_pad = n_tokens_pad * n_expts_act
hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)
hist = hist[:n_expts_tot]
assert hist.dtype == torch.int32
# scratchpad
expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
# output
topk_indx = combined_indx[:n_gates_pad]
gate_indx = combined_indx[n_gates_pad:]
gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1a, blocks2a, MEMSET_BLOCK_A, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
hist, n_expts_tot, n_gates_pad)
blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
_combined_routing_memset[(blocks1a + blocks1b, )](
combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist, #
expt_offs, hist.shape[0], n_expts_tot, partial_hist, # inputs
partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1), # outputs
token_offs_combined, token_offs_combined.stride(0), #
blocks1a, block_pid_map, #
block_m_log2_start, SIZES=block_m_num, BLOCK_A=MEMSET_BLOCK_A, # optimization parameters
BLOCK_N=512, BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters
)
indx_offs = partial_hist
_combined_routing_compute[(blocks2a + blocks2b, )](
topk_indx, gate_indx, gate_scal, # outputs
expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1), # inputs
expt_offs, n_tokens_raw, # input shape
HIST_BLOCK_M, n_expts_act, # constants
hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs
block_m_log2_start, block_m_num, HIST2_BLOCK_M, blocks2a, # etc.
)
ctx.n_tokens_raw = n_tokens_raw
ctx.n_tokens_pad = n_tokens_pad
ctx.n_expts_act = n_expts_act
ctx.save_for_backward(gate_indx)
return hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map
@staticmethod
def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
(gate_indx, ) = ctx.saved_tensors
dgate_scal = dgate_scal[gate_indx]
dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
return dgate_scal, None, None, None
def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix):
return SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix)
# --------------------------
# prune routing
# --------------------------
class PruneRouting(torch.autograd.Function):
@staticmethod
def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
from .compaction import compaction
n_tokens_pad = expt_scal.shape[0]
assert n_expts_tot % simulated_ep == 0
_routing_clear_bitmatrix[(n_tokens_pad, )](
bitmatrix.storage.data,
bitmatrix.storage.data.stride(0),
bitmatrix.storage.data.stride(1),
bitmatrix.storage.data.shape[1],
n_expts_tot // simulated_ep,
BLOCK_N=512,
)
# perform compaction to update expt_scal / expt_indx
expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix)
n_expts_tot = n_expts_tot // simulated_ep
bitmatrix.shape[-1] = n_expts_tot
return expt_scal, expt_indx, bitmatrix
def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep)
# --------------------------
# expt_data
# --------------------------
def log2_power_of_two(x):
assert x > 0 and (x & (x - 1)) == 0, "x must be a power of two"
return x.bit_length() - 1
block_m_log2_start = 4
def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
MEMSET_BLOCK = 512
HIST2_BLOCK_M = 512
device = expt_hist.device
n_expts_tot = n_expts_tot
cdiv = triton.cdiv
# block_ms are all powers-of-two between 16 and 128 (inclusive) on CUDA, and up to 256 on HIP
block_m_log2_end = 9 if is_hip() else 8
block_m_num = block_m_log2_end - block_m_log2_start
if n_gates <= n_expts_tot:
max_n_tiles = n_gates
else:
max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start)
# allocate memory
pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK
dtype = torch.int32
token_offs_combined = torch.empty((block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device)
token_offs_raw = token_offs_combined[0][:n_expts_tot + 1]
token_offs_pad = token_offs_combined[1:]
block_pid_map = torch.empty((block_m_num, pad(max_n_tiles)), dtype=dtype, device=device)
memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK # exact division
# compute outputs
token_offs_pad = token_offs_pad[:, :n_expts_tot + 1]
block_pid_map = block_pid_map[:, :max_n_tiles]
blocks1 = memset_grid + block_m_num + 1
blocks2 = n_expts_tot * block_m_num
return token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num
def _unpack_into_dict(x):
block_m_log2_end = block_m_log2_start + x.shape[0]
x = {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
return x
def compute_expt_data(expt_hist, n_expts_tot, n_gates):
if expt_hist is None:
return ExptData(None, None, None, None)
# this just computes the kernel arguments:
token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
expt_hist, n_expts_tot, n_gates)
_expt_data_memset[(blocks1, )](
expt_hist, n_expts_tot, #
token_offs_combined, token_offs_combined.stride(0), #
block_pid_map, #
block_m_log2_start, SIZES=block_m_num, BLOCK=MEMSET_BLOCK, # optimization parameters
num_warps=4)
_expt_data_compute[(blocks2, )](
expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs
block_m_log2_start, SIZES=block_m_num, BLOCK=HIST2_BLOCK_M, # optimization parameters
num_warps=4)
token_offs_pad = _unpack_into_dict(token_offs_pad)
block_pid_map = _unpack_into_dict(block_pid_map)
return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)
# --------------------------
# routing
# --------------------------
def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens(
expt_scal, expt_indx, n_expts_tot, bitmatrix)
token_offs_pad = _unpack_into_dict(token_offs_pad)
block_pid_map = _unpack_into_dict(block_pid_map)
expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
# pack the matmul data structure
gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx
def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None):
from .topk import topk
if sm_first:
logits = torch.softmax(logits, dim=-1)
expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, #
apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows)
n_expts_tot = logits.shape[-1] // simulated_ep
# mutate bitmatrix
if simulated_ep > 1:
expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, logits.shape[-1], simulated_ep)
return routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act)
# --------------------------
# torch reference
# --------------------------
def compute_expt_data_torch(hist, n_expts_tot, n_gates):
# offset for each expert
device = hist.device
token_offs_raw = torch.cumsum(hist, dim=0)
token_offs_raw = torch.cat((torch.zeros(1, device=device), token_offs_raw))
token_offs_raw = token_offs_raw.int()
# maximum number of tiles for all values of `block_m` considered
block_ms = [16, 32, 64, 128]
if is_hip():
block_ms.append(256)
if n_gates <= n_expts_tot:
max_n_tiles = n_gates
else:
# ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
# ceil_div(x, y): -(-x // y)
max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms))
# fill up tile offset/infos for each block
token_offs_pad = dict()
block_pid_map = dict()
for block_m in block_ms:
n_tiles = (hist + block_m - 1) // block_m # matmul blocks needed
token_offs_pad[block_m] = torch.cumsum(n_tiles, dim=0)
token_offs_pad[block_m] = torch.cat((torch.zeros(1, device=device), token_offs_pad[block_m]))
token_offs_pad[block_m] = token_offs_pad[block_m].int()
# compute data required to drive ragged batch matmul
block_pid_map[block_m] = -torch.ones(max_n_tiles, dtype=torch.int32, device=device)
# for e in range(n_expts_tot):
# offset = token_offs_pad[block_m][e]
# for b in range(n_tiles[e]):
# block_pid_map[block_m][offset + b] = (b << 16) + e
col = torch.arange(max_n_tiles, device=device)
map_vals = torch.arange(n_expts_tot, device=device)[:, None] + (col << 16)[None, :]
map_idxs = token_offs_pad[block_m][:-1, None] + col[None, :]
mask = col[None, :] < n_tiles[:, None]
block_pid_map[block_m].index_put_((map_idxs[mask], ), map_vals.int()[mask])
return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
def topk_torch(vals, k, expt_indx, has_user_provided_indx=False):
# topk of experts
if has_user_provided_indx:
tk_indx = expt_indx
else:
tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
tk_indx = tk_indx.long()
tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
tk_indx = tk_indx.int()
return tk_val, tk_indx
def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):
has_user_provided_indx = expt_indx is not None
n_gates_pad = logits.shape[0] * n_expts_act
if n_rows is not None:
logits = logits[:n_rows, :]
_, n_expts_tot = logits.shape
if sm_first:
logits = torch.softmax(logits, dim=-1)
expt_scal, expt_indx = topk_torch(logits, n_expts_act, expt_indx, has_user_provided_indx=has_user_provided_indx)
if not sm_first:
expt_scal = torch.softmax(expt_scal, dim=-1)
# sort each token's selections by expert
if not has_user_provided_indx:
expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
expt_scal = torch.gather(expt_scal, 1, sort_indices)
# flatten topk data
expt_scal = expt_scal.reshape(-1)
expt_indx = expt_indx.reshape(-1).to(torch.int32)
# sort by expert_id so experts are contiguous for the matmul
topk_indx = torch.argsort(expt_indx, stable=True)
gate_indx = torch.argsort(topk_indx, stable=True)
gate_scal = expt_scal[topk_indx]
hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1).int() # histogram of tokens over experts
# pack the matmul data structure
gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
# compute expt_data
expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad)
return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx
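# Illustrative sketch, not from the upstream triton_kernels source: driving the torch
# reference above with random logits. Assumes a CUDA-capable environment, since the
# reference path still queries the active Triton target (via `is_hip()`).
def _routing_torch_example(n_tokens=8, n_expts_tot=4, n_expts_act=2, device="cuda"):
    logits = torch.randn(n_tokens, n_expts_tot, device=device)
    rdata, gather_indx, scatter_indx = routing_torch(logits, n_expts_act)
    # one gate per (token, selected expert) pair: n_tokens * n_expts_act entries overall
    return rdata, gather_indx, scatter_indx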

View File

@ -0,0 +1 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.

View File

@ -0,0 +1,68 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/routing_details/_expt_data.py
# Triton is licensed under the MIT License.
import triton
import triton.language as tl
@triton.jit
def _cdiv_pow2(n, log2_k):
return (n + ((1 << log2_k) - 1)) >> log2_k
@triton.jit
def _expt_data_memset(Hist, n_expts_tot, MDStarts, tile_starts_stridem, MDTileInfo, first_tile_dim_log2,
SIZES: tl.constexpr, BLOCK: tl.constexpr):
pid = tl.program_id(0)
if pid <= SIZES:
MDStarts += pid * tile_starts_stridem
x_tile = tl.zeros([BLOCK], dtype=MDStarts.dtype.element_ty)
Tile_ptrs = MDStarts + tl.arange(0, BLOCK)
tile_dim_log2 = tl.where(pid == 0, 0, pid + first_tile_dim_log2 - 1)
for i in range(0, n_expts_tot + 1, BLOCK):
offs_n = tl.arange(0, BLOCK) + i
mask_n0 = offs_n < n_expts_tot
hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0)
hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2)
tile_starts = tl.cumsum(hist_tile, 0) + x_tile
x_tile += tl.sum(hist_tile, 0).to(MDStarts.dtype.element_ty)
tl.store(Tile_ptrs, tile_starts - hist_tile)
Tile_ptrs += BLOCK
else:
pid -= (SIZES + 1)
TileInfoOut = MDTileInfo + pid * BLOCK + tl.arange(0, BLOCK)
tl.store(TileInfoOut, 0xffffffff)
@triton.jit
def _expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
SIZES: tl.constexpr, BLOCK: tl.constexpr):
pid = tl.program_id(0)
expt_id = pid // SIZES
buff_id = pid % SIZES
MDTileStarts += buff_id * tile_starts_stridem
MDTileInfo += buff_id * tile_info_stridem
n_tokens = tl.load(Hist + expt_id)
tile_dim_log2 = first_tile_dim_log2 + buff_id
n_blocks = _cdiv_pow2(n_tokens, tile_dim_log2)
tile_off = tl.load(MDTileStarts + expt_id)
MDTileInfo += tile_off
for block_off in range(0, n_blocks, BLOCK):
block_offs = block_off + tl.arange(0, BLOCK)
data = (block_offs << 16) + expt_id
tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks)

View File

@ -0,0 +1,152 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/routing_details/_routing_compute.py
# Triton is licensed under the MIT License.
import triton
import triton.language as tl
from ._expt_data import _expt_data_compute, _expt_data_memset
@triton.jit
def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, # histogram
BLOCK_N: tl.constexpr):
loop_iterations = (hist_size + BLOCK_N - 1) // BLOCK_N
x = tl.zeros([BLOCK_N], ExpertHist.dtype.element_ty)
for i in range(loop_iterations):
offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
mask_n = offs_n < hist_size
hist2 = tl.load(ExpertHist + offs_n, mask=mask_n)
tok_starts = tl.cumsum(hist2, 0) - hist2 + x
x += tl.sum(hist2, 0)
tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
offs_n += BLOCK_N
@triton.jit
def _routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr, expt_id):
offs_m = tl.arange(0, BLOCK_M)
# iterate over input data
curr_sum = 0
for _ in range(0, shape_pm, BLOCK_M):
offs = offs_m * stride_pm + expt_id * stride_pn
curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm)
out = tl.cumsum(curr, 0) + curr_sum
curr_sum += tl.sum(curr, 0)
tl.store(PartialHist + offs, out - curr, mask=offs_m < shape_pm)
offs_m += BLOCK_M
@triton.jit
def _keyed_add(x, y):
# we keep the key in the upper 16 bits of a uint32:
key_mask: tl.constexpr = 0xffff0000
kx = x & key_mask
ky = y & key_mask
z = tl.where(kx == ky, x + y - kx, y)
return z
@triton.jit
def _routing_compute_indx(pid_m, GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm,
stride_pn, TokensStart, n_tokens, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
n_tokens = tl.load(n_tokens)
n_gates = n_tokens * N_EXPTS_ACT
tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)
local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)
# stable-sort by expert ID:
kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
kv_pairs = tl.sort(kv_pairs, 0)
expert = kv_pairs >> 16
offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xffff)
mask = expert != 0xffff
gate_scal = tl.load(ExptScal + offs, mask=mask)
# compute run lengths in expert-sorted order:
x = (kv_pairs & 0xffff0000 | 0x00000001)
expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xffff
gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=mask)
gates += tl.load(TokensStart + expert, mask=mask)
gates += exclusive_run_lengths
tl.store(ScatterIndx + offs, gates, mask=mask)
tl.store(GatherIndx + gates, offs, mask=mask)
tl.store(GateScal + gates, gate_scal, mask=mask)
@triton.jit
def _combined_routing_compute(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, stride_pn,
TokensStart, n_tokens, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr, Hist,
MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
SIZES: tl.constexpr, BLOCK: tl.constexpr, blocks2a):
pid = tl.program_id(0)
if pid < blocks2a:
_expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
SIZES, BLOCK)
else:
pid -= blocks2a
_routing_compute_indx(pid, GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm,
stride_pn, TokensStart, n_tokens, BLOCK_M, N_EXPTS_ACT)
@triton.jit
def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr):
pid_m = tl.program_id(0)
cutoff_word = cutoff // 32
cutoff_bit = cutoff % 32
cutoff_mask = (1 << (cutoff_bit)) - 1
for start_n in range(0, shape_bn, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn)
values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values)
values = tl.where(offs_n > cutoff_word, 0, values)
tl.store(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, values, mask=offs_n < shape_bn)
@triton.jit
def _combined_routing_memset(Indx, size, sentinel, BLOCK: tl.constexpr, ExpertHist, FinalExpertOffs, hist_size,
n_expts_tot, PartialHist, shape_pm, stride_pm, stride_pn, MDStarts, tile_starts_stridem,
blocks1a, MDTileInfo, first_tile_dim_log2, SIZES: tl.constexpr, BLOCK_A: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr):
"""
This kernel essentially combines 6 different pieces of functionality,
statically branching on the value of tl.program_id(0) to decide which
codepath to take.
pid == 0: create the token cumsum
1 <= pid <= SIZES: create a tile cumsum
SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff
blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs
pid == blocks1a + n_expts_tot: compute_expt_offs
pid > blocks1a + n_expts_tot: initialise Indx to sentinel
As each of these is a relatively trivial workload, launching them from
this single trampoline is beneficial as they can execute on different
streaming multiprocessors in parallel.
"""
pid = tl.program_id(0)
if pid < blocks1a:
_expt_data_memset(ExpertHist, n_expts_tot, MDStarts, tile_starts_stridem, MDTileInfo, first_tile_dim_log2,
SIZES, BLOCK_A)
elif pid == n_expts_tot + blocks1a:
_routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
elif pid < n_expts_tot + blocks1a:
_routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M, pid - blocks1a)
else:
offs = (pid - n_expts_tot - blocks1a - 1) * BLOCK + tl.arange(0, BLOCK)
mask = offs < size
tl.store(Indx + offs, sentinel, mask=mask)
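# Illustrative sketch, not from the upstream triton_kernels source: how the host-side
# launch in routing.py sizes the combined memset grid that feeds the pid ranges listed
# in the docstring above. The argument values here are hypothetical.
def _combined_memset_grid_example(n_gates_pad=1024, n_expts_tot=8, memset_block=1024, blocks1a=5):
    blocks1b = triton.cdiv(n_gates_pad * 2, memset_block) + n_expts_tot + 1
    return blocks1a + blocks1b   # total number of programs for _combined_routing_memset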

View File

@ -0,0 +1,139 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/specialize.py
# Triton is licensed under the MIT License.
import inspect
import re
import textwrap
import types
import triton
def cacheable(f):
"""
A decorator that allows you to write something of the form:
@cacheable
def my_kernel(): return (expression dynamically defining a kernel)
such that it interacts gracefully with triton cache and preload.
"""
g = f()
g.fn.__name__ = f.__name__
g.fn.__module__ = f.__module__
g.fn.__qualname__ = f.__qualname__
g.__name__ = f.__name__
g.__module__ = f.__module__
g.__qualname__ = f.__qualname__
g._fn_name = f"{f.__module__}.{f.__qualname__}"
return g
def define_kernel(src, module, attrs=None, **extra_globals):
"""
Dynamically create a Triton function or kernel from a src string,
linking any symbols in the kernel to objects specified by extra_globals.
"""
# create template function
def _empty_fn():
pass
gdict = dict(**(_empty_fn.__globals__))
gdict.update(extra_globals)
f = types.FunctionType(_empty_fn.__code__, gdict)
f.__module__ = module.__name__
src = textwrap.dedent(src)
src = src[src.find("def "):]
stored_functions = []
function_name = src[4:].split("(")[0].strip()
exec_globals = gdict
exec_globals.update({"stored_functions": stored_functions})
exec(src + "\n\nstored_functions.append(" + function_name + ")\n", exec_globals)
f.__signature__ = inspect.signature(stored_functions[0])
f.__name__ = function_name
f.__doc__ = stored_functions[0].__doc__
if attrs is None:
attrs = dict()
f = triton.JITFunction(f, **attrs)
f._unsafe_update_src(src)
return f
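# Illustrative sketch, not from the upstream triton_kernels source: defining a trivial
# copy kernel from a source string, linking `tl` explicitly through extra_globals. The
# kernel name and body are hypothetical.
def _define_kernel_example():
    import sys
    import triton.language as tl
    src = """
def _copy_kernel(X, Y, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    tl.store(Y + offs, tl.load(X + offs, mask=mask), mask=mask)
"""
    return define_kernel(src, sys.modules[__name__], tl=tl)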
def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):
assert isinstance(fn, triton.runtime.jit.JITFunction)
if name is None:
name = f"{fn.__name__}"
# Get original source code
src = inspect.getsource(fn.fn)
src = textwrap.dedent(src)
lines = src.split("\n")
# Skip decorator and def line
def_idx = next(i for i, line in enumerate(lines) if line.strip().startswith("def"))
# separate header vs body LOC
header_end = def_idx
while not lines[header_end].rstrip().endswith(":"):
header_end += 1
body_lines = lines[header_end + 1:]
header_lines = lines[def_idx:header_end + 1]
# clean-up header
header_clean = [
l.split("#", 1)[0].strip() # keep code, discard comment
for l in header_lines
if l.split("#", 1)[0].strip() # skip blankaftercomment lines
]
# decompose arguments
header_src = " ".join(header_clean) # turn it into a single line
m = re.search(r"\((.*)\)\s*:", header_src)
if not m:
raise ValueError("Could not parse function header")
args_str = m.group(1)
args = [arg.strip() for arg in args_str.split(",") if arg.strip()]
non_specialized_args = []
for arg in args:
arg_key = arg.split(":")[0].split("=")[0].strip()
new_args = tuples.get(arg_key, [arg])
if arg_key not in constants:
non_specialized_args += new_args
# add global symbols
spec_fns = {v.__name__: v for k, v in constants.items() if isinstance(v, triton.runtime.jit.JITFunction)}
globals = spec_fns | fn.get_capture_scope()
# build new source code and define kernel dynamically
new_signature = f"def {name}({', '.join(non_specialized_args)}):"
constexpr_lines = [
f" {key}: tl.constexpr = {value.__name__ if callable(value) else value}" for key, value in constants.items()
]
tuple_lines = [
f" {key} = {'(' + ','.join(value) + (',' if len(value)>=1 else '') + ')'}" for key, value in tuples.items()
]
new_src = "\n".join(["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines)
# find function parameters
sig = inspect.signature(triton.runtime.jit.JITFunction.__init__)
params = list(sig.parameters.values())[2:]
attrs = {param.name: getattr(fn, param.name, param.default) for param in params}
# make a new repr which appends the repr of the specialized functions.
base_repr = attrs["repr"]
def new_repr(specialization):
ret = base_repr(specialization)
for spec_fn in spec_fns.values():
spec_repr = spec_fn.repr(None)
if spec_repr:
spec_repr = spec_repr.strip("_")
if spec_repr:
ret += f"_{spec_repr}"
return ret
attrs["repr"] = new_repr
if do_not_specialize:
attrs["do_not_specialize"] = do_not_specialize
ret = define_kernel(new_src, module, attrs, **globals)
return ret

104
triton_kernels/swiglu.py Normal file
View File

@ -0,0 +1,104 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/swiglu.py
# Triton is licensed under the MIT License.
from dataclasses import dataclass
from triton_kernels.numerics import InFlexData, OutFlexData
import torch
import triton
from .swiglu_details._swiglu import _swiglu, _swiglu_fn
from triton_kernels import target_info
@dataclass(frozen=True)
class FlexCtx:
out_data: OutFlexData = OutFlexData()
inp_data: InFlexData = InFlexData()
saturate_inf: bool = False
@dataclass(frozen=True)
class PrecisionConfig:
limit: float
flex_ctx: FlexCtx = FlexCtx()
swiglu_fn = _swiglu_fn
class SwiGLU(torch.autograd.Function):
@staticmethod
def forward(ctx, a, alpha, precision_config, routing_data):
N = a.shape[-1]
M = a.numel() // N
assert a.stride()[-1] == 1
assert a.shape[-1] % 2 == 0
out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device)
flex_ctx = precision_config.flex_ctx
# optimization hyperparameters
BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
num_warps = 4
kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
# launch semi-persistent kernel
N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
num_sms = target_info.num_sms()
if routing_data is not None:
waves_per_sm = 32 if target_info.is_hip() else 128
num_pid = num_sms * (waves_per_sm // num_warps)
M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
else:
M_BLOCKS = triton.cdiv(M, BLOCK_M)
if M_BLOCKS * N_BLOCKS >= 8 * num_sms:
grid = (8 * num_sms, )
else:
grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
n_tokens = None
if routing_data is not None:
n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot]
_swiglu[grid](
flex_ctx.out_data.reinterpret(out),
flex_ctx.out_data.expected_scale,
flex_ctx.out_data.actual_scale,
flex_ctx.out_data.checksum_scale,
flex_ctx.inp_data.reinterpret(a),
flex_ctx.inp_data.scale,
alpha,
M,
N // 2,
a.shape[-1],
1,
out.shape[-1],
1,
precision_config.limit,
n_tokens,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
EVEN_N=(N // 2) % BLOCK_N == 0,
M_BLOCKS=M_BLOCKS,
N_BLOCKS=N_BLOCKS,
flexpoint_saturate_inf=flex_ctx.saturate_inf,
num_warps=num_warps,
**kwargs,
)
out = out.view(a.shape[:-1] + out.shape[-1:])
return out
def swiglu(a, alpha, precision_config, routing_data=None):
return SwiGLU.apply(a, alpha, precision_config, routing_data)
def swiglu_torch(a, alpha, precision_config):
limit = precision_config.limit
a_gelu = a[..., ::2]
if limit is not None:
a_gelu = a_gelu.clamp(max=limit)
a_linear = a[..., 1::2]
if limit is not None:
a_linear = a_linear.clamp(min=-limit, max=limit)
out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu)
out = out_gelu * (a_linear + 1)
return out
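# Illustrative sketch, not from the upstream triton_kernels source: the interleaved
# layout the torch reference expects -- gate values in even columns, linear values in
# odd columns -- with hypothetical alpha / limit values.
def _swiglu_torch_example():
    x = torch.randn(4, 16)                  # last dim must be even (gate/linear pairs)
    cfg = PrecisionConfig(limit=10.0)
    y = swiglu_torch(x, alpha=1.702, precision_config=cfg)
    assert y.shape == (4, 8)                # output halves the last dimension
    return y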

View File

@ -0,0 +1 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.

View File

@ -0,0 +1,106 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/swiglu_details/_swiglu.py
# Triton is licensed under the MIT License.
from triton_kernels.numerics_details.flexpoint import load_scale, float_to_flex, update_scale
import triton
import triton.language as tl
@triton.jit
def clip(x, limit, clip_lower: tl.constexpr):
res = tl.minimum(x, limit)
if clip_lower:
res = tl.maximum(-limit, res)
return res
@triton.jit
def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr):
return tl.max(tl.reshape(tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True), axis=1)
def swiglu_repr(specialization):
signature = specialization.signature
constants = specialization.constants
convert_dtype = lambda dtype: "mxfp4" if "u8" in dtype else dtype
dtypes = "x".join([convert_dtype(f"{signature[i][1:]}") for i in ["Out", "A"]])
blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N"]])
return f"_swiglu_{dtypes}_{blocks}"
def swiglu_launch_metadata(grid, kernel, args):
M, N = args["M"], args["N"]
ret = dict()
ret["name"] = f"{kernel.name} [M = {M}, N = {N}]"
A, Out = args["A"], args["Out"]
ret["bytes"] = Out.numel() * Out.element_size() + A.numel() * A.element_size()
return ret
@triton.jit
def compute_swiglu(gelu, linear, scale, alpha, limit):
gelu = gelu.to(tl.float32) * scale
if limit is not None:
gelu = clip(gelu, limit, clip_lower=False)
linear = linear.to(tl.float32) * scale
if limit is not None:
linear = clip(linear, limit, clip_lower=True)
s = gelu / (1 + tl.exp(-alpha * gelu))
return tl.fma(s, linear, s) # (s * (linear + 1))
@triton.jit(repr=lambda _: "_swiglu")
def _swiglu_fn(input, alpha, limit):
gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
return compute_swiglu(gelu, linear, 1.0, alpha, limit)
@triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale, alpha, M, N, stride_am, stride_an,
stride_outm, stride_outn, limit: tl.constexpr, NTokens, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr):
if NTokens is not None:
M = tl.load(NTokens)
M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M
local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32)
a_scale = load_scale(AScale)
out_expected_scale = load_scale(OutExpectedScale)
for pid in tl.range(tl.program_id(0), M_BLOCKS * N_BLOCKS, tl.num_programs(0), num_stages=2):
pid_m = (pid // N_BLOCKS)
pid_n = (pid % N_BLOCKS)
off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
mask_m = off_m < M
mask_n = off_n < N
packed_off_n = pid_n * BLOCK_N + tl.arange(0, 2 * BLOCK_N) // 2
packed_mask_n = packed_off_n < N
packed_mask_n = tl.max_constancy(packed_mask_n, [16])
# load a
packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N)
packed_offs = off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an
if EVEN_N:
a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.)
else:
if pid_n * BLOCK_N + BLOCK_N <= N:
a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.)
else:
packed_mask = mask_m[:, None] & packed_mask_n[None, :]
a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.)
a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))
out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit)
# update flexpoint stats and divide by scale
# we don't need masking because of the `other` when loading `A`
if OutActualScale is not None:
absmax = thread_local_absmax(out, out.numel, tl.extra.cuda.num_threads())
local_max = tl.maximum(local_max, absmax)
out = float_to_flex(out, out_expected_scale,
None, # ActualScale: local absmax is tracked and updated after the loop
OutChecksumScale, None, Out, flexpoint_saturate_inf)
mask = mask_m[:, None] if EVEN_N else mask_m[:, None] & mask_n[None, :]
tl.store(Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask)
update_scale(local_max, OutActualScale, Out)

View File

@ -0,0 +1,58 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/target_info.py
# Triton is licensed under the MIT License.
import torch
import triton
import triton.language as tl
from triton.language.target_info import (
cuda_capability_geq,
is_cuda,
is_hip,
is_hip_cdna3,
is_hip_cdna4,
)
__all__ = [
"cuda_capability_geq",
"get_cdna_version",
"has_tma_gather",
"has_native_mxfp",
"is_cuda",
"is_hip",
"is_hip_cdna3",
"is_hip_cdna4",
"num_sms",
]
@triton.constexpr_function
def get_cdna_version():
"""
Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
only supports 3 (gfx942) or 4 (gfx950). Returns -1 if it is not AMD
hardware or unsupported architecture
"""
target = tl.target_info.current_target()
if target.backend != 'hip':
return -1
if target.arch == 'gfx942':
return 3
if target.arch == 'gfx950':
return 4
return -1
@triton.constexpr_function
def has_tma_gather():
return cuda_capability_geq(10, 0)
@triton.constexpr_function
def has_native_mxfp():
return cuda_capability_geq(10, 0)
def num_sms():
return torch.cuda.get_device_properties(0).multi_processor_count

219
triton_kernels/tensor.py Normal file
View File

@ -0,0 +1,219 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor.py
# Triton is licensed under the MIT License.
from dataclasses import dataclass, fields
from typing import Type
import torch
from triton.tools.tensor_descriptor import TensorDescriptor
from triton.tools.ragged_tma import create_ragged_descriptor
from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
from .target_info import cuda_capability_geq
from .tensor_details.layout import Layout, StridedLayout
@dataclass
class Storage:
data: torch.Tensor
layout: Layout = None
def __post_init__(self):
assert isinstance(self.data, torch.Tensor)
if self.layout is None:
self.layout = StridedLayout(self.data.shape)
@property
def device(self):
return self.data.device
def is_tma_compliant(self):
# TMAs didn't exist until Hopper
if not cuda_capability_geq(9, 0):
return False
# TMAs only exist for 2D, 3D, 5D inputs
if len(self.data.shape) not in [2, 3, 5]:
return False
# TMAs need at most one stride equal to 1
# and all other strides divisible by 16 bytes
strides = list(self.data.stride())
try:
major_dim = strides.index(1)
except ValueError:
major_dim = -1
ndim = self.data.ndim
bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
compliant = [strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim]
return all(compliant)
def make_dense_tma(self, block_shape, transpose=False):
strides = list(self.data.stride())
shape = list(self.data.shape)
transpose = self.data.stride()[-1] != 1
if transpose:
block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
shape = shape[:-2] + [shape[-1], shape[-2]]
strides = strides[:-2] + [strides[-1], strides[-2]]
if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
indx = strides.index(1)
block_shape[indx] = block_shape[indx] // 2
if shape[-1] % 128 != 0:
raise ValueError("inner shape need to be multiple of 128 for "
"mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs.")
block_shape = self.layout.swizzle_block_shape(block_shape)
return TensorDescriptor(self.data, shape, strides, block_shape)
def make_tma(self, block_shape, mode, transpose=False):
if mode in ["dense", "gather", "scatter"]:
return self.make_dense_tma(block_shape, transpose)
assert mode == "ragged"
ragged_dim = len(self.data.shape) - 2
return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
@dataclass
class IntegerType:
bitwidth: int
@dataclass
class FloatType:
bitwidth_exponent: int
bitwidth_mantissa: int
is_signed: bool
def __post_init__(self):
self.bitwidth = int(self.is_signed) + self.bitwidth_exponent + self.bitwidth_mantissa
BIT = IntegerType(1)
FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True)
def bitwidth(type: IntegerType | FloatType | torch.dtype):
if isinstance(type, torch.dtype):
return type.itemsize * 8
return type.bitwidth
@dataclass
class Tensor:
storage: Storage | torch.Tensor
dtype: IntegerType | FloatType | torch.dtype = None
shape: list[int] | None = None
shape_max: list[int] | None = None
def __post_init__(self):
# set storage
if isinstance(self.storage, torch.Tensor):
self.storage = Storage(self.storage)
# initialize dtype
if self.dtype is None:
self.dtype = self.storage.data.dtype
if bitwidth(self.dtype) < 8 and self.shape is None:
raise ValueError("shape must be provided for sub-byte types")
# initialize shape
if self.shape is None:
self.shape = list(self.storage.data.shape)
# validate shape: all elements must be `int` or numel-1 `torch.Tensor`
is_int = lambda s: isinstance(s, int)
is_item = lambda s: hasattr(s, "numel") and s.numel() == 1
assert all(map(lambda s: is_int(s) or is_item(s), self.shape))
# initialize shape_max
if self.shape_max is None:
self.shape_max = [None] * len(self.shape)
for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
if smax is not None and not is_int(smax):
raise ValueError(f"shape_max[{i}] must be `int` or `None`; got {type(smax)}")
if smax is None:
self.shape_max[i] = s
# validate shape_max: all elements must be `int`
assert all(map(is_int, self.shape_max))
# torch compatibility layer
@property
def ndim(self):
return len(self.shape)
@property
def device(self):
return self.storage.device
def stride(self, i=None):
return self.storage.data.stride() if i is None else self.storage.data.stride(i)
def data_ptr(self):
return self.storage.data.data_ptr()
def numel(self):
return self.storage.data.numel()
def element_size(self):
return bitwidth(self.dtype) // 8
@property
def data(self):
t = self.storage
return t.data if isinstance(t, Storage) else t
def dim(self):
return self.ndim
def size(self, i=None):
if i is None:
return self.shape
return self.shape[i]
@dataclass
class Bitmatrix(Tensor):
"""
Represents a boolean matrix in a packed format where each element occupies
a single bit of memory.
`scratchpad` is either None or an all-zero array of size >= shape[-1]; we pass it along
with the actual bitmatrix to avoid having to launch a separate memset
kernel when we call Bitmatrix::sum().
"""
scratchpad: torch.Tensor = None
def __init__(self, storage, shape, shape_max=None, scratchpad=None):
super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
self.scratchpad = scratchpad
def sum(self, partials_block_size):
_, n_cols = self.shape
dev = self.device
if self.scratchpad is None:
self.scratchpad = clear_sums(n_cols, dev)
out_ret = self.scratchpad[:n_cols]
self.scratchpad = None # throw error if we try to sum again
return sum_bitmatrix_rows(self, out_ret, partials_block_size)
def get_layout(tensor: torch.Tensor | Tensor | None):
if tensor is None:
return None
if isinstance(tensor, Tensor):
return tensor.storage.layout
return StridedLayout
def wrap_torch_tensor(torch_tensor, dtype=None):
if dtype is None:
dtype = torch_tensor.dtype
shape = list(torch_tensor.shape)
shape[torch_tensor.stride().index(1)] *= bitwidth(torch_tensor.dtype) // bitwidth(dtype)
return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)
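# Illustrative sketch, not from the upstream triton_kernels source: wrapping a uint8
# tensor that packs two fp4 values per byte; the logical shape doubles along the
# contiguous axis while the backing storage is untouched.
def _wrap_fp4_example():
    packed = torch.zeros(16, 32, dtype=torch.uint8)   # 32 bytes per row -> 64 fp4 values
    t = wrap_torch_tensor(packed, dtype=FP4)
    assert t.shape == [16, 64]
    return t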
def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
assert isinstance(tensor, Tensor)
old_storage = tensor.storage
old_data = old_storage.layout.unswizzle_data(old_storage.data)
new_layout = layout_cls(old_data.shape, **layout_kwargs)
new_data = new_layout.swizzle_data(old_data)
attrs = {k.name: getattr(tensor, k.name) for k in fields(tensor) if k.name != "storage"}
return Tensor(Storage(new_data, new_layout), **attrs)

View File

@ -0,0 +1 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.

View File

@ -0,0 +1,44 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor_details/layout.py
# Triton is licensed under the MIT License.
from .layout_details.base import Layout
from .layout_details.blackwell_scale import BlackwellMXScaleLayout
from .layout_details.blackwell_value import BlackwellMXValueLayout
from .layout_details.hopper_scale import HopperMXScaleLayout
from .layout_details.hopper_value import HopperMXValueLayout
from .layout_details.cdna4_scale import CDNA4MXScaleLayout
from .layout_details.strided import StridedLayout
from ..target_info import cuda_capability_geq, is_hip_cdna4
__all__ = [
"Layout",
"BlackwellMXValueLayout",
"BlackwellMXScaleLayout",
"HopperMXScaleLayout",
"HopperMXValueLayout",
"CDNA4MXScaleLayout",
"StridedLayout",
]
def make_default_matmul_mxfp4_w_layout(mx_axis: int):
if cuda_capability_geq(10):
# return StridedLayout, dict()
return BlackwellMXValueLayout, dict()
elif cuda_capability_geq(9):
return HopperMXValueLayout, {"mx_axis": mx_axis}
else:
return StridedLayout, dict()
def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
if is_hip_cdna4():
return CDNA4MXScaleLayout, dict()
else:
if cuda_capability_geq(10):
return BlackwellMXScaleLayout, dict()
elif cuda_capability_geq(9):
return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
return StridedLayout, dict()
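# Illustrative sketch, not from the upstream triton_kernels source: both factories
# return a layout class plus the kwargs needed to instantiate it for a given tensor.
# Assumes an active Triton target, since the capability checks query the current device.
def _pick_mxfp4_layouts(mx_axis=1, num_warps=8):
    w_cls, w_kwargs = make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis)
    s_cls, s_kwargs = make_default_matmul_mxfp4_w_scale_layout(mx_axis=mx_axis, num_warps=num_warps)
    return (w_cls, w_kwargs), (s_cls, s_kwargs)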

View File

@ -0,0 +1 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.

View File

@ -0,0 +1,23 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor_details/layout_details/base.py
# Triton is licensed under the MIT License.
from abc import ABC, abstractmethod
class Layout(ABC):
def __init__(self, shape) -> None:
self.initial_shape = shape
@abstractmethod
def swizzle_data(self, data):
pass
@abstractmethod
def unswizzle_data(self, data):
pass
@abstractmethod
def swizzle_block_shape(self, block_shape):
pass

View File

@ -0,0 +1,62 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py
# Triton is licensed under the MIT License.
import math
import triton
import triton.language as tl
import torch
from .base import Layout
SWIZZLE_ALIGN_INNER = 8
SWIZZLE_SIZE_INNER = 4
SWIZZLE_SIZE_OUTER = 128
class BlackwellMXScaleLayout(Layout):
name: str = "BLACKWELL_SCALE"
def __init__(self, shape) -> None:
super().__init__(shape)
*self.leading_shape, self.K, self.N, = shape
self.B = math.prod(self.leading_shape)
self.ALIGN_K = 8
self.ALIGN_N = 128
self.SWIZZLE_K = 4
self.K_pad = (self.K + self.ALIGN_K - 1) // self.ALIGN_K * self.ALIGN_K
self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N
def swizzle_data(self, data):
data = torch.nn.functional.pad(data, (0, self.N_pad - self.N, 0, self.K_pad - self.K))
data = data.transpose(-1, -2).contiguous()
data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.ALIGN_N // 32, 32, self.K_pad // self.SWIZZLE_K,
self.SWIZZLE_K)
data = data.transpose(2, 4).contiguous()
data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256)
return data
def unswizzle_data(self, data):
data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.K_pad // self.SWIZZLE_K, 32, self.ALIGN_N // 32,
self.SWIZZLE_K)
data = data.transpose(2, 4)
data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad)
data = data.transpose(-1, -2)
return data[..., :self.K, :self.N]
def swizzle_block_shape(self, block_shape):
MX_PACK_DIVISOR = 32
MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR
return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256]
@triton.jit
def unswizzle_mx_scale_bw(x, SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER,
SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER,
ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER):
shape_0: tl.constexpr = x.shape[0]
shape_1: tl.constexpr = x.shape[1]
tl.static_assert(shape_1 % SIZE_OUTER == 0)
tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER)
x = x.reshape(shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER)
x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER)
return x

View File

@ -0,0 +1,37 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value.py
# Triton is licensed under the MIT License.
import torch
from .base import Layout
class BlackwellMXValueLayout(Layout):
name: str = "BLACKWELL_VALUE"
def __init__(self, shape) -> None:
super().__init__(shape)
self.shape = shape
def swizzle_data(self, data):
# permutation needed to make `data` row major
to_row_major = sorted(range(data.ndim), key=lambda d: (data.stride(d), d))[::-1]
# permutation needed to retrieve original order
inv = [0] * data.ndim
for i, d in enumerate(to_row_major):
inv[d] = i
# leading dimension must be padded to be aligned to 128
align_dim = lambda x: (x + 128 - 1) // 128 * 128
major_dim = data.stride().index(1)
pad = align_dim(data.shape[major_dim]) - data.shape[major_dim]
data = torch.nn.functional.pad(data.permute(to_row_major), (0, pad)).permute(inv)
return data
def unswizzle_data(self, data: torch.Tensor):
# Trim padding along all dims back to the original shape recorded at init.
assert data.ndim == len(self.shape), "Rank mismatch between data and recorded shape"
sizes = [min(data.size(i), self.shape[i]) for i in range(data.ndim)]
return data[tuple(slice(0, s) for s in sizes)]
def swizzle_block_shape(self, block_shape):
return block_shape

View File

@ -0,0 +1,48 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor_details/layout_details/cdna4_scale.py
# Triton is licensed under the MIT License.
import triton
import triton.language as tl
from .base import Layout
NON_K_PRESHUFFLE_BLOCK_SIZE = 32
class CDNA4MXScaleLayout(Layout):
name: str = "CDNA4_SCALE"
def __init__(self, shape) -> None:
super().__init__(shape)
def swizzle_data(self, data):
block_shape = data.shape
SCALE_K = block_shape[-2]
N = block_shape[-1]
data = data.transpose(-1, -2)
data = data.view(-1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1)
data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
if len(block_shape) == 3:
E = block_shape[0]
data = data.reshape(E, N // 32, SCALE_K * 32)
else:
assert len(block_shape) == 2
data = data.reshape(N // 32, SCALE_K * 32)
return data.transpose(-1, -2)
def unswizzle_data(self, data):
raise NotImplementedError()
def swizzle_block_shape(self, block_shape):
SCALE_K = block_shape[-2]
N = block_shape[-1]
return block_shape[:-2] + [N // 32, SCALE_K * 32]
@triton.jit
def unswizzle_mx_scale_cdna4(x, BLOCK_N: tl.constexpr, MX_SCALE_BLOCK_K: tl.constexpr,
N_PRESHUFFLE_FACTOR: tl.constexpr = NON_K_PRESHUFFLE_BLOCK_SIZE):
x = x.reshape(BLOCK_N // N_PRESHUFFLE_FACTOR, MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2, 1)
x = x.permute(0, 5, 3, 1, 4, 2, 6)
x = x.reshape(BLOCK_N, MX_SCALE_BLOCK_K)
return x

View File

@ -0,0 +1,84 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py
# Triton is licensed under the MIT License.
import torch
import triton
import triton.language as tl
from .base import Layout
class HopperMXScaleLayout(Layout):
name: str = "HOPPER_SCALE"
def __init__(self, shape, mx_axis, num_warps=8) -> None:
assert num_warps & (num_warps - 1) == 0, "warps_n must be a power of 2"
super().__init__(shape)
self.mx_axis = mx_axis
self.num_warps = num_warps
*self.leading_shape, _, _ = shape
def _maybe_mT(self, data):
if self.mx_axis == len(self.leading_shape):
return data.contiguous().mT
return data
def swizzle_data(self, data):
data = self._maybe_mT(data).contiguous()
*batch, M, K = data.shape
SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 8
SWIZZLE_ALIGN_K = 2
pad_m = (SWIZZLE_ALIGN_M - (M % SWIZZLE_ALIGN_M)) % SWIZZLE_ALIGN_M
pad_k = (SWIZZLE_ALIGN_K - (K % SWIZZLE_ALIGN_K)) % SWIZZLE_ALIGN_K
data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m))
*batch, M, K = data.shape
assert data.is_contiguous()
assert M % (
2 * self.num_warps * 2 *
8) == 0 and K % 2 == 0, f"Input tensor must have a subtile of shape (..., {2 * self.num_warps * 2 * 8}, 2)"
b = len(batch)
data = data.reshape(*batch, M // (2 * self.num_warps * 2 * 8), 2, self.num_warps, 2, 8, K // 2, 2)
perm = [0, 2, 5, 1, 4, 6, 3]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.flatten(-5, -1)
data = data.flatten(-3, -2)
assert data.shape[-2] == M // 32
assert data.shape[-1] == K * 32
data = self._maybe_mT(data)
return data
def unswizzle_data(self, data):
data = self._maybe_mT(data)
*batch, M, K = data.shape
b = len(batch)
data = data.reshape(*batch, M // self.num_warps, self.num_warps, K // 64, 2, 8, 2, 2)
perm = [0, 3, 1, 6, 4, 2, 5]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.reshape(*batch, M * 32, K // 32)
data = self._maybe_mT(data)
return data
def swizzle_block_shape(self, block_shape):
return block_shape
@triton.jit
def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constexpr):
"""
Triton inverse of swizzle_mxfp4_scale_hopper
"""
tl.static_assert(len(x.shape) == 2, "NYI")
# implementation assumes mxfp data is packed along the last dimension
x = x.trans() if mx_axis == 0 else x
M: tl.constexpr = x.shape[0]
K: tl.constexpr = x.shape[1]
tl.static_assert(M % num_warps == 0, f"M must be divisible by {num_warps}. Got {M}")
tl.static_assert(K % 64 == 0, f"K must be divisible by 64. Got {K}")
x = x.reshape(M // num_warps, num_warps, K // 64, 2, 8, 2, 2)
x = x.trans(0, 3, 1, 6, 4, 2, 5)
x = x.reshape(M * 32, K // 32)
# implementation assumes mxfp data is packed along the last dimension
x = x.trans() if mx_axis == 0 else x
return x
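A minimal sketch (not vendored code) of the Hopper scale layout above: for an already-aligned uint8 scale tensor, `swizzle_data` and `unswizzle_data` are exact inverses. The module path is assumed from the vendored directory layout; everything runs on CPU.

```python
import torch

from triton_kernels.tensor_details.layout_details.hopper_scale import HopperMXScaleLayout

M, K = 256, 64                                  # aligned: M % (2 * 8 * 2 * 8) == 0, K % 2 == 0
scales = torch.randint(0, 256, (M, K), dtype=torch.uint8)
layout = HopperMXScaleLayout(scales.shape, mx_axis=1, num_warps=8)

swizzled = layout.swizzle_data(scales)
assert swizzled.shape == (M // 32, K * 32)      # rows shrink by 32x, columns grow by 32x
assert torch.equal(layout.unswizzle_data(swizzled), scales)
```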

327
triton_kernels/tensor_details/layout_details/hopper_value.py Normal file
View File

@ -0,0 +1,327 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py
# Triton is licensed under the MIT License.
import torch
import triton
import triton.language as tl
from .base import Layout
def right_shift_unsigned(x, shift):
return (x >> shift) & ((1 << (32 - shift)) - 1)
# -----------------------------------------------------------------------
# Interleave the bits of four consecutive fp4 values (i.e. 16-bits) as:
# 1000000111000000 (first fp4)
# 1000000111000000 (second fp4)
# 1000000111000000 (third fp4)
# 0110110000000000 (fourth fp4)
# This is done so that dequantization can be done in 14 SASS instructions
# -----------------------------------------------------------------------
def _compress_fp4(x):
x = x.to(torch.int32)
return ((x & 0x8) << 12) | ((x & 0x7) << 6)
def _compress_fourth(x):
x = x.to(torch.int32)
return ((x & 0x8) << 11) | ((x & 0x6) << 9) | ((x & 0x1) << 13)
def _pack_bits(x: torch.Tensor, mx_axis: int):
x = x.contiguous()
assert x.shape[-1] % 4 == 0, "Input tensor must have a last dimension divisible by 4"
x = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4))
first = _compress_fp4(x[..., 0]) | (_compress_fp4(x[..., 0] >> 4) << 16)
second = _compress_fp4(x[..., 1]) | (_compress_fp4(x[..., 1] >> 4) << 16)
third = _compress_fp4(x[..., 2]) | (_compress_fp4(x[..., 2] >> 4) << 16)
fourth = _compress_fourth(x[..., 3]) | (_compress_fourth(x[..., 3] >> 4) << 16)
x = first | right_shift_unsigned(second, 3) | right_shift_unsigned(third, 6) | fourth
assert x.is_contiguous()
x = x.view(torch.uint8)
return x
# -----------------------------------------------------------------------
# inverse operation of _pack_bits
# -----------------------------------------------------------------------
def _bf16_to_fp4e2m1(x):
# 0bAxxxxxxBCDxxxxxx (int16) -> 0b0000ABCD (uint8)
assert x.dtype == torch.int16
s = (right_shift_unsigned(x, 15) & 0x1) << 3
em = right_shift_unsigned(x, 6) & 0x7
return (s | em).to(torch.uint8)
def _bf16x2_to_fp4e2m1x2(x):
# 0bAxxxxxxBCDxxxxxx_0bExxxxxxFGHxxxxxx (int32) -> 0bABCD_EFGH (uint8)
assert x.dtype == torch.int32
lo = (x & 0xFFFF).to(torch.int16)
hi = (right_shift_unsigned(x, 16) & 0xFFFF).to(torch.int16)
ret_lo = _bf16_to_fp4e2m1(lo)
ret_hi = _bf16_to_fp4e2m1(hi)
return ret_lo | (ret_hi << 4)
def _unpack_bits(x, mx_axis: int):
x = x.view(torch.int32)
m = 0b10000001110000001000000111000000
a = (x << 1) & 0b10000000000000001000000000000000
b = right_shift_unsigned(x, 3) & 0b00000001100000000000000110000000
c = right_shift_unsigned(x, 7) & 0b00000000010000000000000001000000
unpacked = [x & m, (x << 3) & m, (x << 6) & m, (a | b) | c]
x = torch.stack(unpacked, dim=-1)
x = x.flatten(-2, -1)
x = _bf16x2_to_fp4e2m1x2(x)
return x
# -----------------------------------------------------------------------
class HopperMXValueLayout(Layout):
name: str = "HOPPER_VALUE"
def __init__(self, shape, mx_axis, mma_version=3):
super().__init__(shape)
assert mx_axis in range(len(shape))
self.mx_axis = mx_axis
self.mma_version = mma_version
*self.leading_shape, self.K, self.N, = shape
def _maybe_mT(self, data):
if self.mx_axis == len(self.leading_shape):
return data.mT
return data
def swizzle_data(self, data):
"""
Given a uint8 tensor of shape (*, M, K), returns a tensor of shape
(*, M // 4, K * 4) such that:
1) Groups contiguously all the elements owned by the same thread of 4
mma tiles along the K axis. The following animation shows a similar
grouping for 2 tiles along M and 2 tiles along K rather than 4 along K
as done here:
https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif
2) Moves the elements belonging to thread 4-7 to be contiguous with those
from thread 0-3. This is done to get a full cache line when loading them
from HBM.
mx_axis selects the lhs or rhs of the matmul.
WARNING: Assumes that the matmul will be done in bf16 or fp16!
Implementing it for fp8 is as easy as making the tile size (8, 8)
"""
batch = data.ndim - 2
assert batch >= 0
assert self.mma_version in (2, 3)
data = self._maybe_mT(data)
init_shape = data.shape
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
# Pack the 4 // u8_kwidth subtiles of an mma into a u4x8
contig = (1, u8_kwidth)
scott_trick = (2, 1)
threads = (4, 4)
warp_tile = (2, 2)
k_tile = (1, 4 // u8_kwidth)
sizes = list(data.shape[:-2])
pads = []
# [rest, K, tile, threads] per dimension
for i, (a, b, c, s, d) in enumerate(zip(k_tile, warp_tile, threads, scott_trick, contig)):
pack = a * b * c * s * d
size = data.shape[batch + i]
pad = (pack - size % pack) % pack
pads += [(0, pad)]
sizes.append((size + pad) // pack)
sizes += [a, b, c, s, d]
pads = tuple(x for t in pads[::-1] for x in t)
data = torch.nn.functional.pad(data, pads)
init_shape = data.shape
# 0: rest[0]
# 1: k_tile[0]
# 2: warp_tile[0]
# 3: threads[0]
# 4: scott_trick[0]
# 5: contig[0]
# 6: rest[1]
# 7: k_tile[1]
# 8: warp_tile[1]
# 9: threads[1]
# 10: scott_trick[1]
# 11: contig[1]
data = data.view(*sizes)
# Want [rest[0], threads[0], rest[1], scott_trick[1], scott_trick[0], threads[1], k_tile[1], k_tile[0], warp_tile[1], warp_tile[0], contig[0], contig[1]]
perm = [0, 3, 6, 10, 4, 9, 7, 1, 8, 2, 5, 11]
perm = list(range(batch)) + [batch + p for p in perm]
data = data.permute(*perm).contiguous()
# These are views
data = data.flatten(-10, -1)
data = data.flatten(-3, -2)
assert data.is_contiguous()
assert data.shape[-2] == init_shape[-2] // 4
assert data.shape[-1] == init_shape[-1] * 4
# twiddle the bits
data = _pack_bits(data, self.mx_axis)
data = self._maybe_mT(data)
return data
def unswizzle_data(self, data):
data = self._maybe_mT(data)
data = _unpack_bits(data, self.mx_axis)
*batch, M, K = data.shape
# We have two times the elements if we already upcasted to bfloat16
mult = 2 if data.dtype == torch.bfloat16 else 1
assert M % 4 == 0, "M must be divisible by 4"
assert K % (4 * 8 * 2 * 2 * mult) == 0, f"K must be divisible by {4 * 8 * 2 * 2 * mult}"
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
data = data.reshape(*batch, M // 4, 4, K // (4 * 8 * 2 * 2 * mult), 2, 4, 8 // u8_kwidth, 2, u8_kwidth * mult)
b = len(batch)
perm = [0, 6, 1, 3, 2, 5, 4, 7]
perm = list(range(b)) + [b + p for p in perm]
data = data.permute(*perm)
data = data.reshape(*batch, M * 4, K // 4)
data = self._maybe_mT(data)
return data[..., :self.K, :self.N]
def swizzle_block_shape(self, block_shape):
return block_shape
@triton.jit
def _unshuffle_triton(x, mma_version: tl.constexpr):
"""
Triton inverse of swizzle_mxfp4_value_hopper
"""
tl.static_assert(mma_version == 2 or mma_version == 3, "mma_version must be 2 or 3")
# if mx_axis == 0:
# x = x.trans()
# We have two times the elements if we already upcasted to bfloat16
mult: tl.constexpr = 2 if x.dtype == tl.bfloat16 else 1
M: tl.constexpr = x.shape[0]
K: tl.constexpr = x.shape[1]
tl.static_assert(M % 4 == 0, "M must be divisible by 4")
tl.static_assert(K % (4 * 8 * 2 * 2 * mult) == 0, f"K must be divisible by {4 * 8 * 2 * 2 * mult}")
# We are loading 8 bf16 elements per thread to use ld.global.v4
# Every u8 represents 2 mxfp4 elements
u8_kwidth: tl.constexpr = 8 // 2 if mma_version == 2 else 1
x = x.reshape(M // 4, 4, K // (4 * 8 * 2 * 2 * mult), 2, 4, 8 // u8_kwidth, 2, u8_kwidth * mult)
x = x.trans(0, 6, 1, 3, 2, 5, 4, 7)
x = x.reshape(M * 4, K // 4)
# if mx_axis == 0:
# x = x.trans()
return x
@triton.jit
def _unpack_fp4_to_bf16_triton(x):
# For now we implement just H100 support (mul.bf16x2)
# A100 support is possible via fma
r0, r1 = tl.inline_asm_elementwise(
r"""
{
.reg .b32 b, c, d<7>, scale;
.reg .b32 bias;
mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
// We add the missing bias to the scale directly
and.b32 $0, $4, 0b10000001110000001000000111000000;
mul.bf16x2 $0, $0, bias;
shl.b32 b, $4, 3;
and.b32 $1, b, 0b10000001110000001000000111000000;
mul.bf16x2 $1, $1, bias;
shl.b32 c, $4, 6;
and.b32 $2, c, 0b10000001110000001000000111000000;
mul.bf16x2 $2, $2, bias;
// Unpack last two elements
shl.b32 d0, $4, 1;
and.b32 d1, d0, 0b10000000000000001000000000000000;
shr.b32 d2, $4, 3;
and.b32 d3, d2, 0b00000001100000000000000110000000;
or.b32 d4, d1, d3;
shr.b32 d5, $4, 7;
and.b32 d6, d5, 0b00000000010000000000000001000000;
or.b32 $3, d4, d6;
mul.bf16x2 $3, $3, bias;
}
""",
constraints="=r,=r,=r,=r,r",
args=[x],
dtype=(tl.bfloat16, tl.bfloat16),
is_pure=True,
pack=4,
)
# Concat each pack of 4
x = tl.join(r0, r1)
x = x.reshape(x.shape[0], x.shape[1] // 4, 4, x.shape[2])
x = x.trans(0, 1, 3, 2)
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
return x
@triton.jit
def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
"""
Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements):
(x << 0) & 0b1000000111000000
(x << 3) & 0b1000000111000000
(x << 6) & 0b1000000111000000
((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000)
"""
# upcast values to bfloat16
tl.static_assert(len(x.shape) == 2)
tl.static_assert(mx_axis == 0 or mx_axis == 1, "mx_axis must be 0 or 1")
tl.static_assert(x.shape[1] % 4 == 0)
tl.static_assert(x.dtype == tl.uint8)
if mx_axis == 0:
x = x.trans()
x = _unpack_fp4_to_bf16_triton(x)
x = _unshuffle_triton(x, mma_version=3)
if mx_axis == 0:
x = x.trans()
# upcast scale to bfloat16
# Add bias missing from the bf16 upcasting sequence
# triton / LLVM generates terrible code for this sequence
# scale = scale.to(tl.uint16)
# scale = scale << 7
# scale = scale.to(tl.bfloat16, bitcast=True)
scale = tl.inline_asm_elementwise(
r"""
{
prmt.b32 $0, $2, 0, 0x5140;
shl.b32 $0, $0, 7;
prmt.b32 $1, $2, 0, 0x7362;
shl.b32 $1, $1, 7;
}
""",
constraints="=r,=r,r",
args=[scale],
dtype=tl.bfloat16,
is_pure=True,
pack=4,
)
# Broadcast scale
scale = scale.expand_dims(mx_axis + 1)
scale = scale.broadcast_to(scale.shape[:mx_axis + 1] + [32] + scale.shape[mx_axis + 2:])
scale = scale.reshape(x.shape)
# Combine scale and x
x = x * scale
return x
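As a quick illustration of the interleaving described above `_compress_fp4`, the following standalone snippet (not vendored code) reproduces the mask from the comment: the fp4 sign bit lands in the bf16 sign position (bit 15) and the exponent/mantissa bits land in bits 6-8, which is what lets the PTX in `_unpack_fp4_to_bf16_triton` recover bf16 values with one `and.b32` plus one `mul.bf16x2` per pair.

```python
# Standalone illustration of the _compress_fp4 bit layout (not vendored code).
def compress_fp4(x: int) -> int:
    return ((x & 0x8) << 12) | ((x & 0x7) << 6)

assert compress_fp4(0b1111) == 0b1000000111000000   # the mask used throughout this file
assert compress_fp4(0b1000) == 0b1000000000000000   # sign bit only -> bf16 sign bit
assert compress_fp4(0b0111) == 0b0000000111000000   # exponent + mantissa -> bf16 bits 6-8
```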

21
triton_kernels/tensor_details/layout_details/strided.py Normal file
View File

@ -0,0 +1,21 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/tensor_details/layout_details/strided.py
# Triton is licensed under the MIT License.
from .base import Layout
class StridedLayout(Layout):
name: str = None
def __init__(self, shape) -> None:
super().__init__(shape)
def swizzle_data(self, data):
return data
def unswizzle_data(self, data):
return data
def swizzle_block_shape(self, block_shape):
return block_shape

199
triton_kernels/testing.py Normal file
View File

@ -0,0 +1,199 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/testing.py
# Triton is licensed under the MIT License.
import enum
import functools
import os
import subprocess
import sys
import torch
from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
def assert_equal(ref, tri):
if isinstance(ref, torch.Tensor):
assert torch.all(ref == tri)
else:
assert ref == tri
def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
if tri.dtype.itemsize == 1:
ref_as_type = ref.to(tri.dtype)
if ref.dtype == tri.dtype:
assert torch.all(ref_as_type == tri)
return
ref = ref_as_type
if ref.numel() == 0:
return
if maxtol is None:
maxtol = 2e-2
if rmstol is None:
rmstol = 4e-3
"""
Compare reference values against obtained values.
"""
# cast to float32:
ref = ref.to(torch.float32).detach()
tri = tri.to(torch.float32).detach()
assert ref.shape == tri.shape, f"Tensors must have same size {ref.shape=} {tri.shape=}"
# deal with infinite elements:
inf_mask_ref = torch.isinf(ref)
inf_mask_tri = torch.isinf(tri)
assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensor must have same infinite elements"
refn = torch.where(inf_mask_ref, 0, ref)
trin = torch.where(inf_mask_tri, 0, tri)
# normalise so that RMS calculation doesn't overflow:
eps = 1.0e-30
multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
refn *= multiplier
trin *= multiplier
ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
max_err = torch.max(rel_err).item()
rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
if verbose:
print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol))
print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol))
if max_err > maxtol:
bad_idxs = torch.nonzero(rel_err > maxtol)
num_nonzero = bad_idxs.size(0)
bad_idxs = bad_idxs[:1000]
print("%d / %d mismatched elements (shape = %s) at coords %s" %
(num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist()))
bad_idxs = bad_idxs.unbind(-1)
print("ref values: ", ref[tuple(bad_idxs)].cpu())
print("tri values: ", tri[tuple(bad_idxs)].cpu())
assert max_err <= maxtol
assert rms_err <= rmstol
class ComputeSanitizerTool(enum.Enum):
MEMCHECK = "memcheck"
RACECHECK = "racecheck"
SYNCCHECK = "synccheck"
INITCHECK = "initcheck"
def compute_sanitizer(**target_kwargs):
"""
Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled,
to expose potential memory access errors.
This decorator requires the `request` fixture to be present.
If `run_sanitizer` argument is present and set to False, the sanitizer is not run.
Running tests under compute sanitizer requires launching subprocess and is slow,
so use sparingly
"""
def decorator(test_fn):
@functools.wraps(test_fn)
def wrapper(*args, **kwargs):
if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
test_fn(*args, **kwargs)
return
import psutil
if target_kwargs.pop("clear_torch_cache", False):
# If we don't pop clear_torch_cache, it won't pass
# target_kwargs.items() <= kwargs.items() condition below.
torch.cuda.empty_cache()
tools_to_check = target_kwargs.pop("tools_to_check", [ComputeSanitizerTool.MEMCHECK])
assert isinstance(tools_to_check, list), f"{tools_to_check=}"
assert all(tool in ComputeSanitizerTool for tool in tools_to_check), (
f"{(tool for tool in tools_to_check if tool not in ComputeSanitizerTool)=}")
ppid_name = psutil.Process(os.getppid()).exe()
run_compute_sanitizer = target_kwargs.items() <= kwargs.items()
if "run_sanitizer" in kwargs:
run_compute_sanitizer &= kwargs["run_sanitizer"]
if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
for tool in tools_to_check:
path = os.path.realpath(test_fn.__globals__["__file__"])
# get path of current file
env = {
"PATH": os.environ["PATH"],
"PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
"TORCH_SHOW_CPP_STACKTRACES": "1",
"CUDA_LAUNCH_BLOCKING": "1",
}
if "CUDA_VISIBLE_DEVICES" in os.environ:
env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
assert "request_fixture" in kwargs, (
"memcheck'ed test must have a (possibly unused) `request` fixture")
test_id = kwargs["request_fixture"].node.callspec.id
cmd = f"{path}::{test_fn.__name__}[{test_id}]"
cmd = [
"compute-sanitizer",
"--target-processes=application-only",
"--destroy-on-device-error=context",
f"--tool={tool.value}",
sys.executable,
"-m",
"pytest",
"-vsx",
cmd,
]
for opt in ["--update_checksum", "--ignore_checksum_error"]:
if opt in sys.argv:
cmd.append(opt)
out = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
env=env,
)
sanitizer_ok = "ERROR SUMMARY: 0 errors" in str(
out.stdout) or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout)
test_output = out.stdout
if type(test_output) is bytes:
test_output = test_output.decode()
fail = False
if not sanitizer_ok:
print("compute-sanitizer returned an error")
fail = True
elif out.returncode != 0:
print(
"The test failed due to some other reason: consider running without compute-sanitizer to verify."
)
print(f"{out.returncode=}")
fail = True
if fail:
print("*****************************************************")
print("******************** TEST OUTPUT ********************")
print("*****************************************************")
print(test_output)
print("*****************************************************")
print("****************** TEST OUTPUT END ******************")
print("*****************************************************")
assert None
else:
test_fn(*args, **kwargs)
return wrapper
return decorator
def compute_actual_scale(x, dtype):
max_finite = {
torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
}[dtype]
return x.abs().max() / max_finite
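A small usage sketch (not vendored code) for `assert_close` above: rather than exact equality, it checks a maximum relative error and an RMS relative error against configurable tolerances. It assumes `triton_kernels` is importable as packaged by this change.

```python
import torch
from triton_kernels.testing import assert_close

ref = torch.randn(256, 256)
tri = ref + 1e-3 * torch.randn_like(ref)   # stand-in for a kernel output with small noise
assert_close(ref, tri, maxtol=2e-2, rmstol=4e-3, description="demo", verbose=False)
```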

128
triton_kernels/topk.py Normal file
View File

@ -0,0 +1,128 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/topk.py
# Triton is licensed under the MIT License.
import torch
import triton
from triton_kernels.topk_details._topk_forward import _topk_forward
from triton_kernels.topk_details._topk_backward import _topk_backward
from triton_kernels.tensor import Tensor, Bitmatrix
from typing import Optional, Union
def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
if not isinstance(x, Tensor):
x_shape = [x.shape[0] if n_rows is None else n_rows, x.shape[1]]
x_shape_max = [x.shape[0], x.shape[1]]
x = Tensor(x, shape=x_shape, shape_max=x_shape_max)
cdiv = lambda a, b: (a + b - 1) // b
BLOCK_M = 32
BLOCK_N = 32
BLOCK_S = 128
assert len(x.shape) == 2
assert x.shape_max[-1] < 32768
assert dim == 1
assert return_bitmatrix
n_rows, n_cols = x.shape
n_rows_max, _ = x.shape_max
dev = x.device
# scratchpad tensors
# NOTE: these are not returned
y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev)
if y_indx is not None:
use_provided_indx = True
else:
y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev)
use_provided_indx = False
# create bitmatrix in transposed memory layout:
n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N
n_cols_words = n_cols_pad // 32
bitmatrix = torch.empty((n_cols_words, cdiv(n_rows_max, 32) * 32), dtype=torch.uint32, device=dev)
bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_max]
s_blocks = cdiv(n_cols, BLOCK_S)
s_cols = s_blocks * BLOCK_S
scratchpad = torch.empty((s_cols, ), dtype=torch.int32, device=dev)
pids = max(cdiv(n_rows_max, BLOCK_M), s_blocks)
_topk_forward[(pids, )](
x, x.stride(0), # inputs
y_vals, y_indx, y_vals.stride(0), use_provided_indx, # output [topk]
bitmatrix, bitmatrix.stride(0), bitmatrix.stride(1), # output [bitmatrix]
n_rows, n_cols, # shapes
scratchpad, BLOCK_S, s_blocks, # thing to memset to zero
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, # tunable parameter
APPLY_SOFTMAX=apply_softmax, N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k, # constants
)
bitmatrix_shape = [n_rows, n_cols_words * 32]
bitmatrix_shape_max = [n_rows_max, None]
bitmatrix = Bitmatrix(bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=scratchpad)
return y_vals, y_indx, bitmatrix
def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):
assert dy_vals.shape[-1] == k
n_expts_pad = triton.next_power_of_2(x.shape[-1])
dx = torch.empty_like(x)
_topk_backward[(dy_vals.shape[0], )](
y_indx, y_indx.stride(0), dy_vals, dy_vals.stride(0), x, x.stride(0), # inputs
dx, # outputs
dx.stride(0), x.shape[0], n_rows, x.shape[-1], APPLY_SOFTMAX=apply_softmax, N_EXPTS_ACT=k,
N_EXPTS_PAD=n_expts_pad)
return dx
class TopK(torch.autograd.Function):
@staticmethod
def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows):
y_vals, y_indx, bitmatrix = topk_forward(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
ctx.save_for_backward(x, y_indx)
ctx.apply_softmax = apply_softmax
ctx.k = k
ctx.n_rows = n_rows
return y_vals, y_indx, bitmatrix
@staticmethod
def backward(ctx, dy_vals, _0, _1):
x, y_indx = ctx.saved_tensors
dx = topk_backward(x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax)
return dx, None, None, None, None, None, None
def topk(
x: Union[Tensor, torch.Tensor],
k: int,
apply_softmax: bool = True,
dim: int = 1,
return_bitmatrix: bool = True,
y_indx: Optional[torch.Tensor] = None,
n_rows: Optional[int] = None,
):
"""
Computes the top-k values and indices along a specified dimension of a tensor.
Note that the input can be either a `Tensor` or a `torch.Tensor`, but the output will always be a `torch.Tensor`.
Parameters
----------
x : Union[triton_kernels.Tensor, torch.Tensor]
Input tensor of shape (n_tokens, n_expts).
k : int
Number of top elements to retrieve.
apply_softmax : bool, default True
Whether to apply softmax to the input tensor before computing top-k.
dim : int, default 1
Dimension along which to compute top-k.
return_bitmatrix : bool, default True
Whether to also return a bitmatrix of shape (n_tokens, cdiv(n_expts, 32)).
Each bit [t, b] indicates whether the b-th expert was selected for the t-th token.
y_indx : torch.Tensor, optional
Pre-allocated tensor for storing indices of top-k elements with shape (n_tokens, k).
If provided, we skip the computation of top-k indices and use this tensor instead.
n_rows : int, optional
Number of rows to apply top-k on. If None, we consider all rows in `x`.
Returns
-------
(expt_scal, expt_indx, bitmatrix) : Tuple[torch.Tensor, torch.Tensor, Bitmatrix]
"""
ret = TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
return ret
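A minimal routing sketch (not vendored code) showing how `topk` above is typically called for MoE gating. It launches Triton kernels, so it assumes a CUDA-capable GPU and the packaged `triton_kernels` module; the sizes below are illustrative.

```python
import torch
from triton_kernels.topk import topk

n_tokens, n_experts, k = 128, 64, 4
logits = torch.randn(n_tokens, n_experts, device="cuda", dtype=torch.bfloat16)

expt_scal, expt_indx, bitmatrix = topk(logits, k, apply_softmax=True)
# expt_scal: (n_tokens, k) softmax-normalized weights of the selected experts
# expt_indx: (n_tokens, k) int16 ids of the selected experts
# bitmatrix: Bitmatrix whose bit [t, e] marks expert e as selected for token t
print(expt_scal.shape, expt_indx.shape)
```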

3
triton_kernels/topk_details/__init__.py Normal file
View File

@ -0,0 +1,3 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/topk_details/__init__.py
# Triton is licensed under the MIT License.

55
triton_kernels/topk_details/_topk_backward.py Normal file
View File

@ -0,0 +1,55 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/topk_details/_topk_backward.py
# Triton is licensed under the MIT License.
import triton
import triton.language as tl
@triton.jit
def _topk_backward(
Yi,
stride_ym, # topk indices
DY,
stride_dym, # output gradient values
X,
stride_xm, # input values
DX,
stride_dxm, # input gradient values
n_rows,
NRows,
n_expts_tot,
APPLY_SOFTMAX: tl.constexpr,
N_EXPTS_ACT: tl.constexpr,
N_EXPTS_PAD: tl.constexpr,
):
pid_m = tl.program_id(0)
if NRows is not None:
n_rows = tl.load(NRows)
if pid_m >= n_rows:
return
Yi += pid_m * stride_ym
DY += pid_m * stride_dym
X += pid_m * stride_xm
DX += pid_m * stride_dxm
# --
offs_xn = tl.arange(0, N_EXPTS_PAD)
offs_yn = tl.arange(0, N_EXPTS_ACT)
mask_xn = offs_xn < n_expts_tot
# recompute softmax
y_indx = tl.load(Yi + offs_yn)
x = tl.load(X + y_indx)
x = x.to(tl.float32)
y = tl.softmax(x)
# compute input-gradient
dy = tl.load(DY + offs_yn)
dy = dy.to(tl.float32)
s = tl.sum(y * dy, 0)
# write-back input gradient
tl.store(DX + offs_xn, 0, mask=mask_xn)
tl.debug_barrier()
if APPLY_SOFTMAX:
dx = y * (dy - s)
else:
dx = dy
tl.store(DX + y_indx, dx)
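The backward kernel above zero-fills the row, recomputes softmax over the k selected logits, and, when `APPLY_SOFTMAX` is set, scatters `dx = y * (dy - sum(y * dy))` into the selected positions, i.e. the standard softmax vector-Jacobian product. A standalone check of that identity against PyTorch autograd (not vendored code):

```python
import torch

x = torch.randn(8, requires_grad=True)
dy = torch.randn(8)

y = torch.softmax(x, dim=0)
y.backward(dy)                                    # autograd's dL/dx for upstream grad dy

manual = y.detach() * (dy - (y.detach() * dy).sum())
assert torch.allclose(x.grad, manual, atol=1e-6)
```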

150
triton_kernels/topk_details/_topk_forward.py Normal file
View File

@ -0,0 +1,150 @@
# This file is vendored from the Triton project. DO NOT EDIT THIS FILE DIRECTLY.
# Source: https://github.com/triton-lang/triton/tree/v3.5.1/python/triton_kernels/triton_kernels/topk_details/_topk_forward.py
# Triton is licensed under the MIT License.
import triton
import triton.language as tl
@triton.jit
def get_topmask_and_fullmask(x):
tl.static_assert(x.dtype.is_int_unsigned(), "floating-point value must be passed as bits")
tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
return tm_arr, fm_arr
@triton.jit
def fpval_to_key(x):
tm, fm = get_topmask_and_fullmask(x)
return x ^ tl.where((x & tm) != 0, fm, tm)
@triton.jit
def key_to_fpval(x):
tm, fm = get_topmask_and_fullmask(x)
return x ^ tl.where((x & tm) == 0, fm, tm)
# stable top-k tie-breaks to value with smaller index
@triton.jit
def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr):
return N_EXPTS_PAD - indx
@triton.jit
def key_to_indx(indx, N_EXPTS_PAD: tl.constexpr):
return N_EXPTS_PAD - indx
@triton.jit
def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
BLOCK_N: tl.constexpr):
x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
if x_nbits < 16:
# this ensures that we leave at least 16 bits for expert index
# even if the input dtype is smaller than 16 bits:
y_nbits: tl.constexpr = 32
else:
y_nbits: tl.constexpr = x_nbits * 2
x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}")
x_dtype: tl.constexpr = X.dtype.element_ty
# subtract 1 from loop iterations because we peel the first (masked) iteration:
loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N)
mask_n = offs_x_n[None, :] < n_expts_tot
# first iteration:
X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
x = fpval_to_key(x.to(x_utype, bitcast=True))
x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
acc = tl.topk(x, N_EXPTS_ACT, dim=1)
# subsequent iterations:
for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge
X_ptrs -= BLOCK_N
offs_x_n -= BLOCK_N
x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
x = fpval_to_key(x.to(x_utype, bitcast=True))
x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))
# rotate expert index into upper 16 bits:
# 0000vvvvvvvviiii --> iiii0000vvvvvvvv
acc = (acc << (y_nbits - 16)) | (acc >> 16)
# sort in ascending order of expert (descending order of key)
acc = tl.sort(acc, dim=1, descending=True)
# iiii0000vvvvvvvv --> 0000iiii:
y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32)
y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD)
# iiii0000vvvvvvvv --> vvvvvvvv:
y_values_raw = acc.to(x_utype)
y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)
return y_values, y_indices
@triton.jit
def _topk_forward(X, stride_xm, # inputs
Yv, Yi, stride_ym, # topk values/indices
USE_PROVIDED_INDX: tl.constexpr, Bits, stride_rm: tl.constexpr, stride_rn: tl.constexpr, # bitmatrix
n_rows, n_expts_tot, # shape
S, BLOCK_S: tl.constexpr, s_blocks, # thing to memset
APPLY_SOFTMAX: tl.constexpr, # constant
BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr):
pid = tl.program_id(0)
if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr():
n_rows = tl.load(n_rows)
if pid < s_blocks:
tl.store(S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32))
if pid * BLOCK_M >= n_rows:
# early exit:
return
tl.static_assert(BLOCK_N % 32 == 0)
tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
x_dtype: tl.constexpr = X.dtype.element_ty
# load logits
offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
offs_y_n = tl.arange(0, N_EXPTS_ACT)
mask_m = offs_m[:, None] < n_rows
if USE_PROVIDED_INDX:
Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
y_indices = tl.load(Yi_ptrs, mask=mask_m)
Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices
y_values = tl.load(Xv_ptrs, mask=mask_m)
else:
y_values, y_indices = streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, #
N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N)
# normalize selected values
if APPLY_SOFTMAX:
y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)
# write back
Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
tl.store(Yv_ptrs, y_values, mask=mask_m)
if not USE_PROVIDED_INDX:
Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
tl.store(Yi_ptrs, y_indices, mask=mask_m)
# pack into bitmatrix
y_div = y_indices // 32
y_rem = y_indices % 32
loop_iterations = N_EXPTS_PAD // BLOCK_N
for i in range(loop_iterations):
offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32)
y2 = tl.where(y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0)
r = tl.reduce_or(y2, axis=1)
BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
tl.store(BitsPtrs, r, mask=mask_m)
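The streaming top-k above relies on an order-preserving remap of float bits: `fpval_to_key` flips just the sign bit of non-negative values and all bits of negative values, so unsigned integer comparison agrees with floating-point comparison, and `N_EXPTS_PAD - index` is packed into the low 16 bits so that ties break toward the smaller expert index. Below is a host-side reference of the bit remap for float32 (not vendored code, plain PyTorch):

```python
import torch

def fpval_to_key_ref(x: torch.Tensor) -> torch.Tensor:
    """Order-preserving map from float32 bit patterns to unsigned 32-bit keys."""
    u = x.view(torch.int32).to(torch.int64) & 0xFFFFFFFF   # raw bits, treated as unsigned
    top = 1 << 31
    return torch.where((u & top) != 0, u ^ 0xFFFFFFFF, u ^ top)

vals = torch.tensor([-3.5, -0.25, 0.0, 0.25, 3.5], dtype=torch.float32)
keys = fpval_to_key_ref(vals)
assert torch.all(keys[1:] > keys[:-1])                     # larger float -> larger key
```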