Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

Commit 2ab362cafe (parent a1385243e1)

[None][feat] Include triton-kernels as a packaged dependency

Signed-off-by: Anish Shanbhag <ashanbhag@nvidia.com>
@@ -62417,6 +62417,40 @@ License: `MIT License`
 - `Homepage`: https://github.com/triton-lang/triton/
 
+
+## triton-kernels (1.0.0)
+
+### Licenses
+
+License: `MIT License`
+
+- `LICENSE` (from triton repository root):
+```
+Copyright 2018-2020 Philippe Tillet
+Copyright 2020-2022 OpenAI
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files
+(the "Software"), to deal in the Software without restriction,
+including without limitation the rights to use, copy, modify, merge,
+publish, distribute, sublicense, and/or sell copies of the Software,
+and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+```
+
+### URLs
+
+- `Source`: https://github.com/triton-lang/triton/tree/f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f/python/triton_kernels
+
+
 ## tritonclient (2.63.0)
 
 ### Licenses
@@ -107,33 +107,10 @@ Once again, the function call works successfully, this time using a different fu
 
 ## Using OpenAI Triton Kernels for MoE
 
-OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels; enable them with the steps below:
-
-1. **Build and install Triton** (tested with the commit below):
-
-   ```bash
-   git clone https://github.com/triton-lang/triton.git
-   cd triton
-   # Specific commit verified with TensorRT-LLM
-   git checkout f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f
-   pip install -r python/requirements.txt  # build-time dependencies
-   pip install wheel build
-   python3 setup.py bdist_wheel
-   pip install ./dist/*.whl
-   ```
-
-2. **Expose the Triton kernels to TensorRT-LLM**
-   The kernels are not packaged in the wheel, so set the environment variable `TRITON_ROOT` to the path of your Triton clone:
-
-   ```bash
-   export TRITON_ROOT=/local/user/triton
-   # TensorRT-LLM expects the kernels at:
-   # $TRITON_ROOT/python/triton_kernels
-   ```
-
-3. **Select Triton as the MoE backend**
-
-   • **trtllm-serve** (or other similar commands) — add this snippet to the YAML file passed via `--config`:
+OpenAI ships a set of Triton kernels optimized for its MoE models.
+
+To use the Triton MoE backend with **trtllm-serve** (or other similar commands), add this snippet to the YAML file passed via `--config`:
 
 ```yaml
 moe_config:
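The YAML snippet above is cut off at the hunk boundary; presumably it continues with a `backend: TRITON` entry. The same selection can also be made programmatically through the LLM API. A minimal sketch, assuming `MoeConfig` exposes a `backend` field (as the `moe_backend` values in the test hunks later in this diff suggest) and using a placeholder model path:

```python
# Sketch: select the Triton MoE backend via the LLM API instead of YAML.
# The model path is a placeholder; backend="TRITON" is assumed from the
# moe_backend values exercised by the tests changed in this commit.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import MoeConfig

llm = LLM(
    model="/path/to/gpt-oss-checkpoint",     # placeholder path
    moe_config=MoeConfig(backend="TRITON"),  # assumed field name: backend
)
```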
@@ -66,6 +66,9 @@ etcd3 @ git+https://github.com/kragniz/python-etcd3.git@e58a899579ba416449c4e225
 blake3
 soundfile
 triton==3.5.0
+# triton_kernels provides OpenAI's MoE kernels (matmul_ogs, routing, swiglu)
+# NOTE: the version below should be aligned with the triton version above
+triton-kernels @ git+https://github.com/triton-lang/triton.git@v3.5.0#subdirectory=python/triton_kernels
 tiktoken
 blobfile
 openai-harmony==0.0.4
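With this pin, version alignment is carried entirely by the `@v3.5.0` tag in the direct reference. A minimal sketch for surfacing the installed versions at runtime; note that triton-kernels reports its own version (1.0.0, per the license section above), so the two numbers are not expected to match literally:

```python
# Sketch: report the installed versions of the pinned pair. triton-kernels is
# versioned independently (1.0.0), so alignment is enforced by the git tag in
# requirements, not by comparing these numbers.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("triton", "triton-kernels"):
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "is not installed")
```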
@@ -4,36 +4,15 @@ from typing import Callable, Tuple
 
 import torch
 import torch.nn.functional as F
+from triton_kernels.matmul_ogs import FlexCtx, FnSpecs, FusedActivation, PrecisionConfig, matmul_ogs
+from triton_kernels.numerics import InFlexData
+from triton_kernels.routing import RoutingData, routing
+from triton_kernels.swiglu import swiglu_fn
+from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+from triton_kernels.tensor_details import layout
+from triton_kernels.tensor_details.layout import StridedLayout
 
-IS_TRITON_KERNELS_AVAILABLE = True
-TRITON_KERNELS_UNAVAILABLE_REASON = ""
-
-try:
-    from triton_kernels.matmul_ogs import (
-        FlexCtx,
-        FnSpecs,
-        FusedActivation,
-        PrecisionConfig,
-        matmul_ogs,
-    )
-    from triton_kernels.numerics import InFlexData
-    from triton_kernels.routing import RoutingData, routing
-    from triton_kernels.swiglu import swiglu_fn
-    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
-    from triton_kernels.tensor_details import layout
-    from triton_kernels.tensor_details.layout import StridedLayout
-
-    from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import TritonEPRouter
-
-except Exception as _e:
-    IS_TRITON_KERNELS_AVAILABLE = False
-    TRITON_KERNELS_UNAVAILABLE_REASON = f"{type(_e).__name__}: {_e}"
-
-    FlexCtx = FnSpecs = FusedActivation = PrecisionConfig = matmul_ogs = None
-    InFlexData = RoutingData = routing = swiglu_fn = None
-    FP4 = convert_layout = wrap_torch_tensor = None
-    layout = StridedLayout = None
-    TritonEPRouter = None
+
+from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import TritonEPRouter
 
 
 # copied from transformers.integrations.mxfp4::swizzle_mxfp4 with minor modification
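For context, the guard being deleted implements the usual optional-dependency pattern: record the failure reason at import time, then fail loudly at first use. A minimal standalone sketch of that pattern; `require_triton_kernels` is a hypothetical helper, not a function in this codebase:

```python
# Sketch of the retired optional-import pattern (now unnecessary because
# triton-kernels is a hard, packaged dependency).
IS_TRITON_KERNELS_AVAILABLE = True
TRITON_KERNELS_UNAVAILABLE_REASON = ""

try:
    from triton_kernels.matmul_ogs import matmul_ogs  # representative import
except Exception as _e:
    IS_TRITON_KERNELS_AVAILABLE = False
    TRITON_KERNELS_UNAVAILABLE_REASON = f"{type(_e).__name__}: {_e}"
    matmul_ogs = None


def require_triton_kernels() -> None:
    # Hypothetical helper: raise with the recorded reason at first use,
    # the way TritonLinear.__init__ used to before this commit.
    if not IS_TRITON_KERNELS_AVAILABLE:
        raise ImportError(
            f"Triton kernels are not available: {TRITON_KERNELS_UNAVAILABLE_REASON}")
```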
@@ -9,7 +9,6 @@ from tensorrt_llm._torch.auto_deploy.utils.pattern_matcher import (
     register_ad_pattern,
 )
 
-from ...custom_ops.fused_moe.mxfp4_moe import IS_TRITON_KERNELS_AVAILABLE
 from ...utils.module import get_submodule_of_param
 from ...utils.node_utils import is_op
 from ..interface import BaseTransform, TransformInfo, TransformRegistry

@@ -220,11 +219,7 @@ class InsertMXFP4MLP(BaseTransform):
         shared_config,
     ) -> Tuple[GraphModule, TransformInfo]:
         qcfg = factory.get_quant_config()
-        if (
-            not qcfg
-            or qcfg.get("quant_method", "") != self.algo_name
-            or not IS_TRITON_KERNELS_AVAILABLE
-        ):
+        if not qcfg or qcfg.get("quant_method", "") != self.algo_name:
             return gm, TransformInfo(
                 skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
             )
@@ -1,32 +1,19 @@
 from __future__ import annotations
 
-import os
-import sys
 from typing import Dict, List, NamedTuple, Optional
 
 import torch
 import torch.nn as nn
 import triton
 import triton.language as tl
 
-IS_TRITON_KERNELS_AVAILABLE = False
-# We expect to find triton_kernels under $TRITON_ROOT/python/triton_kernels
-# Triton upstream commit f3067cd3bd0c29065fa4ecdb724b6f29cbabea5f has been verified.
-triton_root = os.getenv('TRITON_ROOT')
-if triton_root:
-    triton_root = os.path.abspath(
-        os.path.join(triton_root, 'python', 'triton_kernels'))
-    if os.path.exists(triton_root) and triton_root not in sys.path:
-        sys.path.insert(0, triton_root)
-    assert triton.__version__ >= "3.4.0", "Triton kernels are detected but the Triton wheel is too old"
-    import triton_kernels.swiglu
-    from triton_kernels.matmul_ogs import (FlexCtx, FnSpecs, FusedActivation,
-                                           PrecisionConfig, matmul_ogs)
-    from triton_kernels.numerics import InFlexData
-    from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch
-    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
-    from triton_kernels.tensor_details import layout
-    IS_TRITON_KERNELS_AVAILABLE = True
+import triton_kernels.swiglu
+from triton_kernels.matmul_ogs import (FlexCtx, FnSpecs, FusedActivation,
+                                       PrecisionConfig, matmul_ogs)
+from triton_kernels.numerics import InFlexData
+from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch
+from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
+from triton_kernels.tensor_details import layout
 
 from ...model_config import ModelConfig
 from ..linear import TensorParallelMode, load_weight_shard

@@ -1295,8 +1282,6 @@ class TritonFusedMoE(MoE):
             weight_loading_mode=weight_loading_mode,
             layer_idx=layer_idx,
         )
-        if not IS_TRITON_KERNELS_AVAILABLE:
-            raise ImportError("Triton kernels are not available.")
         if torch.cuda.get_device_capability()[0] != 9 and self.ep_size > 1:
             raise NotImplementedError(
                 "TritonFusedMoE is only supported on Hopper with EP size > 1.")
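One caveat in the deleted guard: `triton.__version__ >= "3.4.0"` compares strings lexicographically, so a hypothetical "3.10.0" would sort before "3.4.0". A sketch of a semantically correct gate, assuming the `packaging` library is available in the environment:

```python
# Sketch: compare versions semantically rather than lexicographically.
import triton
from packaging.version import Version

MIN_TRITON = Version("3.4.0")  # minimum taken from the removed assert

if Version(triton.__version__) < MIN_TRITON:
    raise ImportError(
        f"triton {triton.__version__} is too old; >= {MIN_TRITON} is required")
```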
@@ -4,20 +4,15 @@ from typing import Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
+from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, matmul_ogs
+from triton_kernels.numerics import InFlexData
 
 from tensorrt_llm._torch.peft.lora.layer import LoraLayer
 from tensorrt_llm.mapping import Mapping
 
 from ...models.modeling_utils import QuantConfig
-# Reuse the common Triton import setup
-from .fused_moe.fused_moe_triton import (IS_TRITON_KERNELS_AVAILABLE,
-                                         maybe_update_stride,
+from .fused_moe.fused_moe_triton import (maybe_update_stride,
                                          swizzle_weight_and_scale)
-
-if IS_TRITON_KERNELS_AVAILABLE:
-    from triton_kernels.matmul_ogs import (FlexCtx, PrecisionConfig, matmul_ogs)
-    from triton_kernels.numerics import InFlexData
 
 from .linear import (Linear, LinearMethodBase, TensorParallelMode,
                      WeightsLoadingConfig, copy_weight, load_weight_shard,
                      load_weights_fused_gate_up_helper,

@@ -383,9 +378,6 @@ class TritonLinear(Linear):
         use_custom_cublas_mm: bool = False,
         lora: Optional[LoraLayer] = None,
     ):
-        if not IS_TRITON_KERNELS_AVAILABLE:
-            raise ImportError("Triton kernels are not available. "
-                              "Please install the required dependencies.")
         assert not use_custom_cublas_mm, "TritonLinear does not support custom cublas mm."
 
         super().__init__(
@@ -48,8 +48,6 @@ from defs.conftest import get_sm_version, is_sm_100f
 
 from tensorrt_llm import LLM
 from tensorrt_llm._torch.model_config import MoeLoadBalancerConfig
-from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
-    IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
                                  DeepSeekSparseAttentionConfig,
                                  EagleDecodingConfig, KvCacheConfig, MoeConfig,

@@ -3633,8 +3631,6 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
                                        {"ENABLE_CONFIGURABLE_MOE": env_value})
 
         if moe_backend == "TRITON":
-            if not IS_TRITON_KERNELS_AVAILABLE:
-                pytest.skip("TRITON moe backend is not available.")
             if get_sm_version() < 90:
                 pytest.skip("TRITON moe backend requires Hopper or newer.")
         if moe_backend in ["CUTLASS", "TRTLLM"] and get_sm_version() < 100:

@@ -4072,8 +4068,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
         mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
                           {"scores_filter": "exact_match,flexible-extract"})
-        if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE:
-            pytest.skip("Triton kernels are not available")
 
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,

@@ -4143,10 +4137,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         patch_mpi_pool_session_for_env(mocker,
                                        {"ENABLE_CONFIGURABLE_MOE": env_value})
 
-        if moe_backend == "TRITON":
-            if not IS_TRITON_KERNELS_AVAILABLE:
-                pytest.skip("Triton kernels are not available")
-
         MAX_OUTPUT_LEN = 128179
         MAX_INPUT_LEN = 32768
 

@@ -4221,9 +4211,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
         mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
                           {"scores_filter": "exact_match,flexible-extract"})
-        if moe_backend == "TRITON":
-            if not IS_TRITON_KERNELS_AVAILABLE:
-                pytest.skip("Triton kernels are not available")
 
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,

@@ -4259,8 +4246,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
         mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
                           {"scores_filter": "exact_match,flexible-extract"})
-        if not IS_TRITON_KERNELS_AVAILABLE:
-            pytest.skip("Triton kernels are not available")
         monkeypatch.setenv("OVERRIDE_QUANT_ALGO", "W4A16_MXFP4")
 
         pytorch_config = dict(

@@ -4303,10 +4288,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
     def test_w4_2gpus(self, kv_cache_dtype, moe_backend, tp_size, pp_size,
                       ep_size, attention_dp, cuda_graph, overlap_scheduler,
                       mocker):
-        if moe_backend == "TRITON":
-            if not IS_TRITON_KERNELS_AVAILABLE:
-                pytest.skip("Triton kernels are not available")
-
         pytorch_config = dict(
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None)

@@ -4381,10 +4362,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
         ids=["cutlass", "trtllm", "triton"])
     def test_w4_chunked_prefill(self, kv_cache_dtype, moe_backend, mocker):
-        if moe_backend == "TRITON":
-            if not IS_TRITON_KERNELS_AVAILABLE:
-                pytest.skip("Triton kernels are not available")
-
         MAX_OUTPUT_LEN = 128179
         MAX_INPUT_LEN = 32768
 

@@ -4451,10 +4428,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         ids=["cutlass", "trtllm", "triton"])
     def test_eagle3_4gpus(self, moe_backend, one_model, overlap_scheduler,
                           mocker):
-        if moe_backend == "TRITON":
-            if not IS_TRITON_KERNELS_AVAILABLE:
-                pytest.skip("Triton kernels are not available")
-
         if get_sm_version() == 90:
             pytest.skip(
                 "https://nvbugs/5636916: Remaining Hopper Eagle Accuracy Issue for only TP=4"

@@ -4643,10 +4616,6 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         ids=["cutlass", "trtllm", "triton"])
     def test_eagle3_2gpus(self, moe_backend, one_model, overlap_scheduler,
                           mocker):
-        if moe_backend == "TRITON":
-            if not IS_TRITON_KERNELS_AVAILABLE:
-                pytest.skip("Triton kernels are not available")
-
         MAX_OUTPUT_LEN = 128179
         MAX_INPUT_LEN = 32768
 
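After this change the availability skips are gone and only hardware gates remain, following the `get_sm_version()` pattern visible above. A standalone sketch of that remaining guard, using `torch.cuda.get_device_capability()` directly; the assumption that the repo helper equals `major * 10 + minor` is mine:

```python
# Sketch: skip on pre-Hopper GPUs (SM < 90) rather than on kernel availability.
import pytest
import torch


def test_triton_moe_requires_hopper():
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required")
    major, minor = torch.cuda.get_device_capability()
    if major * 10 + minor < 90:  # assumed equivalent of get_sm_version()
        pytest.skip("TRITON moe backend requires Hopper or newer.")
    # ... a real test body would construct the TRITON-backed MoE here ...
```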
@@ -5,9 +5,6 @@ import torch
 import torch.distributed as dist
 from _dist_test_utils import get_device_counts
 
-from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.mxfp4_moe import (
-    IS_TRITON_KERNELS_AVAILABLE,
-)
 from tensorrt_llm._torch.auto_deploy.distributed.common import spawn_multiprocess_job
 
 

@@ -109,10 +106,6 @@ def _run_mxfp4_mlp_ep_dtype_test(num_experts: int, topk: int, rank: int, world_s
     torch.testing.assert_close(part_out, ref_out, rtol=5e-2, atol=5e-2, equal_nan=True)
 
 
-@pytest.mark.skipif(
-    not IS_TRITON_KERNELS_AVAILABLE,
-    reason="triton_kernels unavailable",
-)
 @pytest.mark.parametrize("num_experts", [6, 8])
 @pytest.mark.parametrize("topk", [4])  # must be <= num_experts
 @pytest.mark.parametrize("device_count", get_device_counts())
@@ -7,8 +7,6 @@ from transformers import AutoTokenizer
 from utils.llm_data import llm_models_root
 
 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
-    IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MoeConfig
 
 configs = """

@@ -50,9 +48,6 @@ def dump_config_json(dst_dir):
 
 @pytest.mark.parametrize("moe_backend", ["CUTLASS", "TRITON"])
 def test_gpt_oss_trtllmgen(moe_backend):
-    if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE:
-        pytest.skip("Triton kernels are not available")
-
     prompts = [
         "How are you?",
         "Hello, my name is",
@@ -42,8 +42,6 @@ from tensorrt_llm._torch.modules.fused_moe import (
 from tensorrt_llm._torch.modules.fused_moe.quantization import \
     NVFP4CutlassFusedMoEMethod
 # isort: on
-from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
-    IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
 from tensorrt_llm._utils import get_sm_version, mpi_disabled, mpi_rank
 from tensorrt_llm.mapping import Mapping

@@ -93,8 +91,6 @@ def test_fused_moe(moe_backend,
                    mapping=None):
 
     if moe_backend == "TRITON":
-        if not IS_TRITON_KERNELS_AVAILABLE:
-            pytest.skip("Triton kernels are not available")
         if dtype != torch.bfloat16:
             pytest.skip("Unsupported for TritonFusedMoE")
         if routing_cls != RenormalizeMoeRoutingMethod:

@@ -528,8 +524,6 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
 def test_fused_moe_fp8(moe_backend, dtype, routing_cls, bias):
 
     if moe_backend == "TRITON":
-        if not IS_TRITON_KERNELS_AVAILABLE:
-            pytest.skip("Triton kernels are not available")
         if dtype != torch.bfloat16:
             pytest.skip("Unsupported for TritonFusedMoE")
         if routing_cls != RenormalizeMoeRoutingMethod:

@@ -2548,8 +2542,6 @@ def test_fused_moe_wfp4a16(dtype, hidden_size, moe_backend,
 @pytest.mark.parametrize("dynamic_quant", [True, False])
 def test_fused_moe_triton_mxfp4(experts, hidden_size, intermediate_size,
                                 fp8_activation, bias, dynamic_quant):
-    if not IS_TRITON_KERNELS_AVAILABLE:
-        pytest.skip("Triton kernels are not available")
     if torch.cuda.get_device_capability()[0] < 10 and fp8_activation:
         pytest.skip("Latest Triton requires BF16 activation on Hopper")
     if torch.cuda.get_device_capability()[0] >= 10 and not fp8_activation:
@@ -7,8 +7,6 @@ import torch
 from mpi4py import MPI
 from utils.util import check_accuracy, skip_pre_hopper
 
-from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
-    IS_TRITON_KERNELS_AVAILABLE
 from tensorrt_llm._torch.modules.linear import Linear
 from tensorrt_llm._torch.modules.triton_linear import TritonLinear
 from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

@@ -23,9 +21,6 @@ MPI.pickle.__init__(
 
 @pytest.mark.parametrize("linear_cls", [Linear, TritonLinear])
 def test_linear_unquantized(linear_cls):
-    if not IS_TRITON_KERNELS_AVAILABLE and linear_cls is TritonLinear:
-        pytest.skip("Triton kernels are not available")
-
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
     num_tokens = 128

@@ -58,9 +53,6 @@ def test_linear_unquantized(linear_cls):
 
 @pytest.mark.parametrize("linear_cls", [Linear, TritonLinear])
 def test_linear_fp8qdq(linear_cls):
-    if not IS_TRITON_KERNELS_AVAILABLE and linear_cls is TritonLinear:
-        pytest.skip("Triton kernels are not available")
-
     torch.manual_seed(0)
     torch.cuda.manual_seed(0)
     num_tokens = 128

@@ -104,8 +96,6 @@ def test_linear_fp8qdq(linear_cls):
 @pytest.mark.parametrize("activation_dtype",
                          [torch.bfloat16, torch.float8_e4m3fn])
 def test_linear_mxfp4(activation_dtype):
-    if not IS_TRITON_KERNELS_AVAILABLE:
-        pytest.skip("Triton kernels are not available")
     if torch.cuda.get_device_capability(
     )[0] < 10 and activation_dtype == torch.float8_e4m3fn:
         pytest.skip("Latest Triton requires BF16 activation on Hopper")