mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
470 lines
14 KiB
Python
470 lines
14 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import copy
|
|
import functools
|
|
import subprocess
|
|
from collections import namedtuple
|
|
from itertools import product
|
|
from typing import Callable, List
|
|
|
|
import pytest
|
|
|
|
# Maps namedtuple field names (and precision values such as 'fp16') to the
# command-line flags understood by the fmha/fmhca test executables.
# Consumed by fmha_harness when assembling the subprocess command line.
field2arg = {
    'seq_len': '-s',
    'seq_len_q': '-s-q',  # fmhca specific
    'seq_len_kv': '-s-kv',  # fmhca specific
    'min_seq_len': '-min-s',
    'head_dim': '-d',
    'batch': '-b',
    'num_head': '-h',
    # precision values: the *value* of the 'precision' field maps directly
    # to a dtype flag (see fmha_harness).
    'fp16': '-fp16',
    'bf16': '-bf16',
    'fp16_fp32': '-fp16-fp32',
    'int8': '-int8',
    'e4m3': '-e4m3',
    # boolean fields: emitted as a bare flag only when the field is True.
    'use_interleaved': '-il',
    'use_tma': '-use-tma',
    'force_non_warp_specialization': '-force-non-warp-specialization',
    'force_non_flash_attention': '-force-non-flash-attention'
}
# Field order for the FmhaArgs namedtuple (self-attention). The single
# 'precision' field stands in for the individual dtype flags in field2arg;
# fmha_harness maps its value back to the matching command-line flag.
fields_fmha = [
    'seq_len', 'min_seq_len', 'head_dim', 'batch', 'num_head', 'precision',
    'use_interleaved', 'use_tma', 'force_non_warp_specialization',
    'force_non_flash_attention'
]
# Field order for the FmhcaArgs namedtuple (cross-attention): separate query
# and key/value sequence lengths instead of a single seq_len.
fields_fmhca = [
    'seq_len_q',
    'seq_len_kv',
    'min_seq_len',
    'head_dim',
    'batch',
    'num_head',
    'precision',
    'use_interleaved',
]
# Defaults (positional, matching fields_fmha): seq_len=256, min_seq_len=1,
# head_dim=64, batch=1, num_head=1, precision='fp16', use_interleaved=False,
# use_tma=False, force_non_warp_specialization=True,
# force_non_flash_attention=False.
FmhaArgs = namedtuple('FmhaArgs',
                      fields_fmha,
                      defaults=(256, 1, 64, 1, 1, 'fp16', False, False, True,
                                False))
# Defaults (positional, matching fields_fmhca): seq_len_q=256, seq_len_kv=256,
# min_seq_len=1, head_dim=64, batch=1, num_head=1, precision='fp16',
# use_interleaved=False.
FmhcaArgs = namedtuple('FmhcaArgs',
                       fields_fmhca,
                       defaults=(256, 256, 1, 64, 1, 1, 'fp16', False))
# custom test name
def idfn(x: "Callable | List[Callable]") -> str:
    """Build a readable pytest parameter id from a rule or a list of rules.

    A list of callables becomes their dot-joined __name__s; an empty list
    becomes "default"; a single callable becomes its __name__.

    Note: the original annotation was ``List[Callable] or Callable`` which
    evaluates to just ``List[Callable]`` at runtime (``or`` is a boolean
    operator, not a type union); fixed via a string annotation so no new
    typing import is required.
    """
    if isinstance(x, list):  # plain builtin check; typing.List isinstance is deprecated
        if len(x) > 0:
            return ".".join([i.__name__ for i in x])
        else:
            return "default"
    else:
        return x.__name__
def combinations_base():
    """Exhaustive functional sweep over seq_len x head_dim for the fp16-family
    precisions, toggling TMA, warp specialization and flash attention.

    Returns the variable-length set (min_seq_len=1) followed by the same
    shapes repeated with min_seq_len pinned to seq_len.
    """
    grid = product(
        [32, 64, 96, 128, 192, 256, 384, 512, 1024, 2048, 4096, 8192, 16384,
         32768],                                   # seq_len
        [1],                                       # min_seq_len
        [16, 32, 40, 64, 80, 128, 160, 256, 512],  # head_dim
        [3],                                       # batch
        [4],                                       # num_head
        ['fp16', 'bf16', 'fp16_fp32'],             # precision
        [False],                                   # use_interleaved
        [False, True],                             # use_tma
        [False, True],                             # force_non_warp_specialization
        [False, True],                             # force_non_flash_attention
    )
    varlen = [FmhaArgs(*values) for values in grid]
    # Same shapes with a fixed sequence length (min_seq_len == seq_len).
    pinned = [arg._replace(min_seq_len=arg.seq_len) for arg in varlen]
    return varlen + pinned
def combinations_fp16():
    """The base sweep with every entry's precision forced to fp16."""
    return [arg._replace(precision='fp16') for arg in combinations_base()]
def reduced_combinations_fp16():
    """fp16 sweep trimmed to the common sequence lengths 128/256/512."""
    kept_seq_lens = (128, 256, 512)
    return [
        arg for arg in combinations_fp16() if arg.seq_len in kept_seq_lens
    ]
def reduced2x_combinations_fp16():
    """fp16 sweep trimmed twice: seq_len in 128/256/512 and head_dim in 32/64."""
    kept_seq_lens = (128, 256, 512)
    kept_head_dims = (32, 64)
    return [
        arg for arg in combinations_fp16()
        if arg.seq_len in kept_seq_lens and arg.head_dim in kept_head_dims
    ]
def combinations_bf16():
    """The base sweep with every entry's precision forced to bf16."""
    return [arg._replace(precision='bf16') for arg in combinations_base()]
def reduced_combinations_bf16():
    """bf16 sweep trimmed to the common sequence lengths 128/256/512."""
    kept_seq_lens = (128, 256, 512)
    return [
        arg for arg in combinations_bf16() if arg.seq_len in kept_seq_lens
    ]
def reduced2x_combinations_bf16():
    """bf16 sweep trimmed twice: seq_len in 128/256/512 and head_dim in 32/64."""
    kept_seq_lens = (128, 256, 512)
    kept_head_dims = (32, 64)
    return [
        arg for arg in combinations_bf16()
        if arg.seq_len in kept_seq_lens and arg.head_dim in kept_head_dims
    ]
def combinations_int8():
    """int8 coverage over short sequences and small head dims.

    Returns the variable-length set (min_seq_len=1) followed by the same
    shapes repeated with min_seq_len pinned to seq_len.
    """
    grid = product(
        [32, 64, 96, 128, 192, 256, 384, 512],  # seq_len
        [1],                                    # min_seq_len
        [16, 32, 64],                           # head_dim
        [3],                                    # batch
        [2],                                    # num_head
        ['int8'],                               # precision
    )
    varlen = [FmhaArgs(*values) for values in grid]
    pinned = [arg._replace(min_seq_len=arg.seq_len) for arg in varlen]
    return varlen + pinned
def combinations_fp16_bench():
    """fp16 benchmark shapes: a head_dim sweep at (seq_len=512, batch=16),
    then the same head_dims repeated at longer sequences with batch halved
    each time so the overall token budget stays comparable.
    """
    base = [
        FmhaArgs(*values) for values in product(
            [512],                  # seq_len
            [512],                  # min_seq_len (fixed-length batches)
            [64, 128, 256, 512],    # head_dim
            [16],                   # batch
            [16],                   # num_head
            ['fp16'],               # precision
        )
    ]
    shapes = list(base)
    # Longer sequences, smaller batches; each variant derives from `base`.
    for seq, batch in ((1024, 8), (2048, 4), (4096, 2), (32768, 1)):
        shapes += [
            arg._replace(seq_len=seq, min_seq_len=seq, batch=batch)
            for arg in base
        ]
    return shapes
def combinations_fp16_sd_bench():
    """fp16 benchmark shapes in a Stable-Diffusion-like regime: a batch sweep
    at (seq_len=4096, head_dim=40), then the same batches repeated at shorter
    sequences with larger head dims.
    """
    base = [
        FmhaArgs(*values) for values in product(
            [4096],               # seq_len
            [4096],               # min_seq_len (fixed-length batches)
            [40],                 # head_dim
            [2, 4, 8, 16, 32],    # batch
            [8],                  # num_head
            ['fp16'],             # precision
        )
    ]
    shapes = list(base)
    # Shorter sequences with wider heads; each variant derives from `base`.
    for seq, dim in ((1024, 80), (256, 160), (64, 160)):
        shapes += [
            arg._replace(seq_len=seq, min_seq_len=seq, head_dim=dim)
            for arg in base
        ]
    return shapes
def combinations_fmhca():
    """Cross-attention (fmhca) shapes: query lengths 1024/2304/4096 against a
    77-token key/value context, head_dim 40/80/160, batch 1/4, all fp16.

    Equivalent to e.g.:
        bin/fmhca.exe -b 1 -s-q 4096 -min-s 4096 -d 40

    Unlike the fmha sweeps, only the fixed-length variants are returned:
    every entry has min_seq_len pinned to seq_len_q, and the min_seq_len=1
    base set is discarded.
    """
    grid = product(
        [1024, 2304, 4096],  # seq_len_q
        [77],                # seq_len_kv (77 marked "?" in the original — unclear why)
        [1],                 # min_seq_len (overwritten below)
        [40, 80, 160],       # head_dim
        [1, 4],              # batch
        [16],                # num_head
        ['fp16'],            # precision
    )
    return [
        FmhcaArgs(*values)._replace(min_seq_len=values[0]) for values in grid
    ]
def combinations_e4m3():
    """e4m3 (FP8) coverage at head_dim 64 for seq_len in 128/256/384/512.

    Equivalent in spirit to e.g.:
        bin/fmha.exe -v 0 -runs 1 -s 512 -d 64 -min-s 1 -b 3 -h 2 -e4m3

    Returns the variable-length set (min_seq_len=1) followed by the same
    shapes repeated with min_seq_len pinned to seq_len.
    """
    grid = product(
        [128, 256, 384, 512],  # seq_len
        [1],                   # min_seq_len
        [64],                  # head_dim
        [3],                   # batch
        [2],                   # num_head
        ['e4m3'],              # precision
    )
    varlen = [FmhaArgs(*values) for values in grid]
    pinned = [arg._replace(min_seq_len=arg.seq_len) for arg in varlen]
    return varlen + pinned
def combinations_int8_interleaved():
    """int8 interleaved-layout (-il) coverage.

    Equivalent in spirit to e.g.:
        bin/fmha.exe -v 0 -runs 1 -il -int8 -s 128 -d 64 -min-s 1 -b 1
    over seq_len 128/192/256/384, batch 1/128 and num_head 1/16.

    Returns the variable-length set (min_seq_len=1) followed by the same
    shapes repeated with min_seq_len pinned to seq_len.
    """
    grid = product(
        [128, 192, 256, 384],  # seq_len
        [1],                   # min_seq_len
        [64],                  # head_dim
        [1, 128],              # batch
        [1, 16],               # num_head
        ['int8'],              # precision
        [True],                # use_interleaved
    )
    varlen = [FmhaArgs(*values) for values in grid]
    pinned = [arg._replace(min_seq_len=arg.seq_len) for arg in varlen]
    return varlen + pinned
def combinations_small():
    """A tiny fp16 smoke set: one long sequence (4096) at batch 1 and 512.

    No min_seq_len == seq_len variants are added here.
    """
    grid = product(
        [4096],    # seq_len
        [1],       # min_seq_len
        [64],      # head_dim
        [1, 512],  # batch
        [1],       # num_head
        ['fp16'],  # precision
    )
    return [FmhaArgs(*values) for values in grid]
def base_command(exe_path):
    """Start a command line for the given executable: quiet output (-v 0)
    and a single run (-runs 1)."""
    common_flags = ['-v', '0', '-runs', '1']
    return [exe_path] + common_flags
def apply_rule(rule):
    """
    Decorator factory for tests which accepts a rule f(fmha_arg, **kwargs) and
    appends it to the harness's 'rules' kwarg, to filter out specific
    combinations of arguments.

    The rule is skipped only when the wrapped call passes a truthy
    'disable_rules' kwarg; a missing or falsy 'disable_rules' applies it.
    """

    def apply_rule_(fmha_harness):
        # make wrapper look like the original fmha_harness to avoid
        # interference with pytest inner workings
        @functools.wraps(fmha_harness)
        def fmha_harness_wrapper(**kwargs):
            # rules (dtype = pytest.fixture) is mutable; deepcopy to avoid
            # changing test states
            rules_copy = copy.deepcopy(kwargs.get('rules', []))
            # Apply the rule unless disable_rules is present and truthy.
            # (Replaces a bare `except:` that would also have swallowed
            # KeyboardInterrupt/SystemExit; a missing key still applies
            # the rule, matching the original intent.)
            if not kwargs.get('disable_rules', False):
                rules_copy.append(rule)
            kwargs_copy = copy.deepcopy(kwargs)
            kwargs_copy['rules'] = rules_copy
            return fmha_harness(**kwargs_copy)

        return fmha_harness_wrapper

    return apply_rule_
def sanitize_prompt(prompt):
    """Return the prompt token list with empty-string entries removed."""
    cleaned = []
    for token in prompt:
        if token != '':
            cleaned.append(token)
    return cleaned
def fmha_harness(exe_path, fmha_arg, rules=None, dryrun=False, **kwargs):
    """
    Build and run (or dry-run) one fmha/fmhca executable invocation.

    exe_path: path to executable
    fmha_arg: FmhaArgs/FmhcaArgs-style namedtuple with the arguments to pass
    rules: a list of functionals f(fmha_arg, **kwargs) that accept fmha_arg
        and additional arguments; each may return extra command-line tokens
        (a list) or None to contribute nothing. Defaults to no rules.
        (The default was `rules=[]`, a shared mutable default; None is
        backward-compatible and avoids cross-call aliasing.)
    dryrun: print command line without actually invoking it
    **kwargs: optional kwargs to pass to rules

    Fails the running pytest test if the subprocess exits non-zero.
    """
    prompt = base_command(exe_path)
    for rule in (rules or []):
        extra = rule(fmha_arg, **kwargs)
        # Previously `prompt += "" ` on None, which relied on iterating an
        # empty string; skip explicitly instead.
        if extra is not None:
            prompt += extra
    # Boolean fields become a bare flag emitted only when True.
    flag_fields = ('use_interleaved', 'use_tma',
                   'force_non_warp_specialization',
                   'force_non_flash_attention')
    for k, v in fmha_arg._asdict().items():
        if k == 'precision':
            # kv pair (precision, dtype) maps to -dtype in command line
            prompt += [field2arg[v]]
        elif k in flag_fields:
            if v is True:
                # kv pair (key, true) maps to field2arg in command line
                prompt += [field2arg[k]]
        else:
            # numeric field: flag followed by its value
            prompt += [field2arg[k], str(v)]
    prompt = sanitize_prompt(prompt)
    print(f'Full prompt: "{" ".join(prompt)}"')
    if not dryrun:
        try:
            subprocess.run(prompt, check=True)
        except subprocess.CalledProcessError as e:
            pytest.fail(
                f'Exception caught during subprocess call: "{" ".join(prompt)}" returns {e.returncode}'
            )