Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
test: Accuracy test improvement (Part 3.1): Extend accuracy test suite with LLM API and initial implementation of trtllm-eval (#3167)
* add eval_llmapi; tmp commit port to CLI tool; move; setup llmapi; fix spec_dec_algo; _update_from_hf_quant_config; fix; migrate test_pytorch.py; fix fp8 block scales; fix fp8 rowwise; adj alpha; move test_pytorch.py cases; move; rename test_accuracy.py to test_cli.py; clean (squashed commits)
* fix cnn_dailymail
* renaming to cli flow
* rename MMLU
* rename
* add error
* fix

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
This commit is contained in: parent bf02b9144f, commit b2f69db507
@ -549,10 +549,10 @@ def main(args):
|
||||
if runtime_rank == 0 and args.eval_task != "eval_context_ppl":
|
||||
logger.info(
|
||||
"---------------------------------------------------------")
|
||||
logger.info("TensorRT-LLM Generated : ")
|
||||
logger.info(f" Input : {datapoint[dataset_input_key]}")
|
||||
logger.info(f"\n Reference : {datapoint[dataset_output_key]}")
|
||||
logger.info(f"\n Output : {output}")
|
||||
logger.info("TensorRT-LLM Generated: ")
|
||||
logger.info(f" Input: {datapoint[dataset_input_key]}")
|
||||
logger.info(f"\n Reference: {datapoint[dataset_output_key]}")
|
||||
logger.info(f"\n Output: {output}")
|
||||
logger.info(
|
||||
"---------------------------------------------------------")
|
||||
|
||||
@ -609,9 +609,9 @@ def main(args):
|
||||
)
|
||||
|
||||
logger.debug('-' * 100)
|
||||
logger.debug(f"Input : {datapoint[dataset_input_key]}")
|
||||
logger.debug(f"Input: {datapoint[dataset_input_key]}")
|
||||
logger.debug(f'TensorRT-LLM Output: {output_tensorrt_llm}')
|
||||
logger.debug(f"Reference : {datapoint[dataset_output_key]}")
|
||||
logger.debug(f"Reference: {datapoint[dataset_output_key]}")
|
||||
|
||||
data_point_idx += max_batch_size
|
||||
ite_count += 1
|
||||
@ -666,10 +666,10 @@ def main(args):
|
||||
if runtime_rank == 0 and args.eval_task != "eval_context_ppl":
|
||||
logger.info(
|
||||
"---------------------------------------------------------")
|
||||
logger.info("HF Generated : ")
|
||||
logger.info(f" Input : {datapoint[dataset_input_key]}")
|
||||
logger.info(f"\n Reference : {datapoint[dataset_output_key]}")
|
||||
logger.info(f"\n Output : {output}")
|
||||
logger.info("HF Generated: ")
|
||||
logger.info(f" Input: {datapoint[dataset_input_key]}")
|
||||
logger.info(f"\n Reference: {datapoint[dataset_output_key]}")
|
||||
logger.info(f"\n Output: {output}")
|
||||
logger.info(
|
||||
"---------------------------------------------------------")
|
||||
|
||||
@ -722,9 +722,9 @@ def main(args):
|
||||
)
|
||||
|
||||
logger.debug('-' * 100)
|
||||
logger.debug(f"Input : {datapoint[dataset_input_key]}")
|
||||
logger.debug(f"Input: {datapoint[dataset_input_key]}")
|
||||
logger.debug(f'HF Output: {output_hf}')
|
||||
logger.debug(f"Reference : {datapoint[dataset_output_key]}")
|
||||
logger.debug(f"Reference: {datapoint[dataset_output_key]}")
|
||||
|
||||
data_point_idx += max_batch_size
|
||||
ite_count += 1
|
||||
@ -761,14 +761,14 @@ def main(args):
|
||||
}
|
||||
for key in computed_metrics_tensorrt_llm.keys():
|
||||
logger.info(
|
||||
f" {key} : {computed_metrics_tensorrt_llm[key]*100} ({computed_std_dev_tensorrt_llm[key]*100})"
|
||||
f" {key}: {computed_metrics_tensorrt_llm[key]*100} ({computed_std_dev_tensorrt_llm[key]*100})"
|
||||
)
|
||||
else:
|
||||
computed_metrics_tensorrt_llm = metric_tensorrt_llm[
|
||||
beam_idx].compute()
|
||||
for key in computed_metrics_tensorrt_llm.keys():
|
||||
logger.info(
|
||||
f" {key} : {computed_metrics_tensorrt_llm[key]*100}"
|
||||
f" {key}: {computed_metrics_tensorrt_llm[key]*100}"
|
||||
)
|
||||
if args.check_accuracy and beam_idx == 0:
|
||||
rouge1 = computed_metrics_tensorrt_llm['rouge1'] * 100
|
||||
@ -797,7 +797,7 @@ def main(args):
|
||||
computed_metrics_hf = metric_hf[beam_idx].compute()
|
||||
if args.eval_task != "eval_context_ppl":
|
||||
for key in computed_metrics_hf.keys():
|
||||
logger.info(f' {key} : {computed_metrics_hf[key]*100}')
|
||||
logger.info(f' {key}: {computed_metrics_hf[key]*100}')
|
||||
if args.eval_ppl and args.batch_size == 1:
|
||||
logger.info(
|
||||
f" Per-token perplexity: {np.mean(ppls_hf[beam_idx])}")
|
||||
|
||||
setup.py (1 line changed)
@ -236,6 +236,7 @@ setup(
            'trtllm-refit=tensorrt_llm.commands.refit:main',
            'trtllm-bench=tensorrt_llm.commands.bench:main',
            'trtllm-serve=tensorrt_llm.commands.serve:main',
            'trtllm-eval=tensorrt_llm.commands.eval:main'
        ],
    },
    scripts=['tensorrt_llm/llmapi/trtllm-llmapi-launch'],
tensorrt_llm/commands/eval.py (new file, 154 lines)
@ -0,0 +1,154 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
||||
import tensorrt_llm.profiler as profiler
|
||||
|
||||
from .._torch.llm import LLM as PyTorchLLM
|
||||
from .._torch.pyexecutor.config import PyTorchConfig
|
||||
from ..evaluate import MMLU, CnnDailymail
|
||||
from ..llmapi import LLM, BuildConfig, KvCacheConfig
|
||||
from ..llmapi.llm_utils import update_llm_args_with_extra_options
|
||||
from ..logger import logger, severity_map
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option(
|
||||
"--model",
|
||||
required=True,
|
||||
type=str,
|
||||
help="model name | HF checkpoint path | TensorRT engine path",
|
||||
)
|
||||
@click.option("--tokenizer",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path | Name of the tokenizer."
|
||||
"Specify this value only if using TensorRT engine as model.")
|
||||
@click.option("--backend",
|
||||
type=click.Choice(["pytorch", "tensorrt"]),
|
||||
default="pytorch",
|
||||
help="Set to 'pytorch' for pytorch path. Default is cpp path.")
|
||||
@click.option('--log_level',
|
||||
type=click.Choice(severity_map.keys()),
|
||||
default='info',
|
||||
help="The logging level.")
|
||||
@click.option("--max_beam_width",
|
||||
type=int,
|
||||
default=BuildConfig.max_beam_width,
|
||||
help="Maximum number of beams for beam search decoding.")
|
||||
@click.option("--max_batch_size",
|
||||
type=int,
|
||||
default=BuildConfig.max_batch_size,
|
||||
help="Maximum number of requests that the engine can schedule.")
|
||||
@click.option(
|
||||
"--max_num_tokens",
|
||||
type=int,
|
||||
default=BuildConfig.max_num_tokens,
|
||||
help=
|
||||
"Maximum number of batched input tokens after padding is removed in each batch."
|
||||
)
|
||||
@click.option(
|
||||
"--max_seq_len",
|
||||
type=int,
|
||||
default=BuildConfig.max_seq_len,
|
||||
help="Maximum total length of one request, including prompt and outputs. "
|
||||
"If unspecified, the value is deduced from the model config.")
|
||||
@click.option("--tp_size", type=int, default=1, help='Tensor parallelism size.')
|
||||
@click.option("--pp_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help='Pipeline parallelism size.')
|
||||
@click.option("--ep_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="expert parallelism size")
|
||||
@click.option("--gpus_per_node",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of GPUs per node. Default to None, and it will be "
|
||||
"detected automatically.")
|
||||
@click.option("--kv_cache_free_gpu_memory_fraction",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="Free GPU memory fraction reserved for KV Cache, "
|
||||
"after allocating model weights and buffers.")
|
||||
@click.option("--trust_remote_code",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Flag for HF transformers.")
|
||||
@click.option("--extra_llm_api_options",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a YAML file that overwrites the parameters")
|
||||
@click.pass_context
|
||||
def main(ctx, model: str, tokenizer: Optional[str], log_level: str,
|
||||
backend: str, max_beam_width: int, max_batch_size: int,
|
||||
max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int,
|
||||
ep_size: Optional[int], gpus_per_node: Optional[int],
|
||||
kv_cache_free_gpu_memory_fraction: float, trust_remote_code: bool,
|
||||
extra_llm_api_options: Optional[str]):
|
||||
logger.set_level(log_level)
|
||||
build_config = BuildConfig(max_batch_size=max_batch_size,
|
||||
max_num_tokens=max_num_tokens,
|
||||
max_beam_width=max_beam_width,
|
||||
max_seq_len=max_seq_len)
|
||||
|
||||
kv_cache_config = KvCacheConfig(
|
||||
free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction)
|
||||
|
||||
if backend == "tensorrt":
|
||||
backend = None
|
||||
pytorch_backend_config = None
|
||||
if backend == "pytorch":
|
||||
pytorch_backend_config = PyTorchConfig(enable_overlap_scheduler=True)
|
||||
|
||||
llm_args = {
|
||||
"model": model,
|
||||
"tokenizer": tokenizer,
|
||||
"tensor_parallel_size": tp_size,
|
||||
"pipeline_parallel_size": pp_size,
|
||||
"moe_expert_parallel_size": ep_size,
|
||||
"gpus_per_node": gpus_per_node,
|
||||
"trust_remote_code": trust_remote_code,
|
||||
"build_config": build_config,
|
||||
"kv_cache_config": kv_cache_config,
|
||||
"backend": backend,
|
||||
"pytorch_backend_config": pytorch_backend_config,
|
||||
}
|
||||
|
||||
if extra_llm_api_options is not None:
|
||||
llm_args = update_llm_args_with_extra_options(llm_args,
|
||||
extra_llm_api_options)
|
||||
|
||||
profiler.start("trtllm init")
|
||||
if backend == 'pytorch':
|
||||
llm = PyTorchLLM(**llm_args)
|
||||
else:
|
||||
llm = LLM(**llm_args)
|
||||
profiler.stop("trtllm init")
|
||||
elapsed_time = profiler.elapsed_time_in_sec("trtllm init")
|
||||
logger.info(f"TRTLLM initialization time: {elapsed_time:.3f} seconds.")
|
||||
|
||||
# Pass llm to subcommands
|
||||
ctx.obj = llm
|
||||
|
||||
|
||||
main.add_command(CnnDailymail.command)
|
||||
main.add_command(MMLU.command)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
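Note: the main group above builds one LLM (TensorRT or PyTorch backend) and passes it to the dataset subcommands through ctx.obj. A rough programmatic equivalent of `trtllm-eval --model <checkpoint> cnn_dailymail`, using the APIs this commit adds, is sketched below; the checkpoint path and the sample count are placeholders, not values from the commit.

from tensorrt_llm.evaluate import CnnDailymail
from tensorrt_llm.llmapi import LLM, BuildConfig, KvCacheConfig
from tensorrt_llm.sampling_params import SamplingParams

# Placeholder checkpoint; any HF model directory supported by the LLM API would do.
llm = LLM(model="/path/to/hf_checkpoint",
          build_config=BuildConfig(),
          kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.9))

# Mirrors the cnn_dailymail subcommand defaults: 924 prompt tokens, 100 new tokens.
sampling_params = SamplingParams(max_tokens=100, truncate_prompt_tokens=924)
evaluator = CnnDailymail(num_samples=64)  # 64 is an arbitrary sample count
rouge1 = evaluator.evaluate(llm, sampling_params)
llm.shutdown()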
tensorrt_llm/evaluate/__init__.py (new executable file, 19 lines)
@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .cnn_dailymail import CnnDailymail
from .mmlu import MMLU

__all__ = ["CnnDailymail", "MMLU"]
tensorrt_llm/evaluate/cnn_dailymail.py (new file, 90 lines)
@ -0,0 +1,90 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
import click
|
||||
import datasets
|
||||
import evaluate
|
||||
|
||||
from .._torch import LLM as PyTorchLLM
|
||||
from ..llmapi import LLM, RequestOutput
|
||||
from ..logger import logger
|
||||
from ..sampling_params import SamplingParams
|
||||
from .interface import Evaluator
|
||||
|
||||
|
||||
class CnnDailymail(Evaluator):
|
||||
|
||||
def __init__(self,
|
||||
dataset_path: str = "ccdv/cnn_dailymail",
|
||||
num_samples: int = None,
|
||||
random_seed: int = 0,
|
||||
rouge_path: str = "rouge"):
|
||||
self.data = datasets.load_dataset(dataset_path, "3.0.0", split="test")
|
||||
self.data = self.data.shuffle(random_seed)
|
||||
if num_samples is None:
|
||||
self.num_samples = self.data.num_rows
|
||||
else:
|
||||
self.num_samples = min(num_samples, self.data.num_rows)
|
||||
self.rouge = evaluate.load(rouge_path)
|
||||
|
||||
def generate_samples(self) -> Iterable[tuple]:
|
||||
for i, sample in enumerate(self.data):
|
||||
if i >= self.num_samples:
|
||||
break
|
||||
prompt = sample["article"] + " TL;DR:"
|
||||
prompt = prompt.strip().replace(" n't", "n't")
|
||||
yield prompt, sample["highlights"]
|
||||
|
||||
def compute_score(self, outputs: List[RequestOutput],
|
||||
references: List[str]) -> float:
|
||||
for beam_idx in range(len(outputs[0].outputs)):
|
||||
metrics = self.rouge.compute(
|
||||
predictions=[output.outputs[0].text for output in outputs],
|
||||
references=references)
|
||||
logger.info(f"Beam {beam_idx} rouge scores:")
|
||||
for key in metrics.keys():
|
||||
logger.info(f"\t{key}: {metrics[key]*100:.3f}")
|
||||
if beam_idx == 0:
|
||||
rouge1 = metrics["rouge1"] * 100
|
||||
return rouge1
|
||||
|
||||
@click.command("cnn_dailymail")
|
||||
@click.option("--dataset_path", type=str, default="ccdv/cnn_dailymail")
|
||||
@click.option("--num_samples", type=int, default=None)
|
||||
@click.option("--random_seed", type=int, default=0)
|
||||
@click.option("--rouge_path", type=str, default="rouge")
|
||||
@click.option("--max_input_length", type=int, default=924)
|
||||
@click.option("--max_output_length", type=int, default=100)
|
||||
@click.option("--check_accuracy", is_flag=True, default=False)
|
||||
@click.option("--accuracy_threshold", type=float, default=15)
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, dataset_path: str, num_samples: int, random_seed: int,
|
||||
rouge_path: str, max_input_length: int, max_output_length: int,
|
||||
check_accuracy: bool, accuracy_threshold: float) -> None:
|
||||
llm: Union[LLM, PyTorchLLM] = ctx.obj
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_output_length,
|
||||
truncate_prompt_tokens=max_input_length)
|
||||
evaluator = CnnDailymail(dataset_path,
|
||||
num_samples=num_samples,
|
||||
random_seed=random_seed,
|
||||
rouge_path=rouge_path)
|
||||
accuracy = evaluator.evaluate(llm, sampling_params)
|
||||
llm.shutdown()
|
||||
|
||||
if check_accuracy:
|
||||
assert accuracy >= accuracy_threshold, f"Expected accuracy >= {accuracy_threshold}, but got {accuracy}"
|
||||
tensorrt_llm/evaluate/interface.py (new file, 61 lines)
@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod, abstractstaticmethod
from typing import Iterable, List, Optional, Union

from tqdm import tqdm

import tensorrt_llm.profiler as profiler

from .._torch import LLM as PyTorchLLM
from ..llmapi import LLM, RequestOutput
from ..logger import logger
from ..sampling_params import SamplingParams


class Evaluator(ABC):

    @abstractmethod
    def generate_samples(self) -> Iterable[tuple]:
        raise NotImplementedError()

    @abstractmethod
    def compute_score(self, outputs: List[RequestOutput], references: List[str],
                      *auxiliaries) -> float:
        raise NotImplementedError()

    def evaluate(self,
                 llm: Union[LLM, PyTorchLLM],
                 sampling_params: Optional[SamplingParams] = None) -> float:
        profiler.start("trtllm exec")
        outputs, references, auxiliaries = [], [], []
        for prompt, reference, *aux in tqdm(self.generate_samples(),
                                            desc="Submitting requests"):
            output = llm.generate_async(prompt, sampling_params)
            outputs.append(output)
            references.append(reference)
            auxiliaries.append(aux)
        for output in tqdm(outputs, desc="Fetching responses"):
            output.result()
        profiler.stop("trtllm exec")
        elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
        logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")

        score = self.compute_score(outputs, references, *zip(*auxiliaries))
        return score

    @abstractstaticmethod
    def command(ctx, *args, **kwargs) -> None:
        raise NotImplementedError()
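The Evaluator base class only asks subclasses for generate_samples and compute_score; async request submission, result collection, and timing are handled by evaluate. A minimal hypothetical subclass (the exact-match dataset and scoring below are illustrative, not part of this commit):

from typing import Iterable, List

from tensorrt_llm.evaluate.interface import Evaluator
from tensorrt_llm.llmapi import RequestOutput


class ExactMatch(Evaluator):
    # Hypothetical evaluator: percentage of outputs that exactly match the reference.

    def __init__(self, samples: List[tuple]):
        self.samples = samples  # each sample is a (prompt, reference) pair

    def generate_samples(self) -> Iterable[tuple]:
        yield from self.samples

    def compute_score(self, outputs: List[RequestOutput],
                      references: List[str]) -> float:
        matches = [out.outputs[0].text.strip() == ref.strip()
                   for out, ref in zip(outputs, references)]
        return 100.0 * sum(matches) / max(len(matches), 1)

    @staticmethod
    def command(ctx, *args, **kwargs) -> None:
        raise NotImplementedError()  # no CLI wiring for this illustrative evaluator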
tensorrt_llm/evaluate/mmlu.py (new file, 247 lines)
@ -0,0 +1,247 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
import random
|
||||
from typing import Iterable, List, Union
|
||||
|
||||
import click
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .._torch import LLM as PyTorchLLM
|
||||
from ..llmapi import LLM, RequestOutput
|
||||
from ..logger import logger
|
||||
from ..sampling_params import SamplingParams
|
||||
from .interface import Evaluator
|
||||
|
||||
|
||||
class MMLU(Evaluator):
|
||||
CHOICES = ["A", "B", "C", "D"]
|
||||
SUBJECT_TO_SUBCATEGORIES = {
|
||||
"abstract_algebra": ["math"],
|
||||
"anatomy": ["health"],
|
||||
"astronomy": ["physics"],
|
||||
"business_ethics": ["business"],
|
||||
"clinical_knowledge": ["health"],
|
||||
"college_biology": ["biology"],
|
||||
"college_chemistry": ["chemistry"],
|
||||
"college_computer_science": ["computer science"],
|
||||
"college_mathematics": ["math"],
|
||||
"college_medicine": ["health"],
|
||||
"college_physics": ["physics"],
|
||||
"computer_security": ["computer science"],
|
||||
"conceptual_physics": ["physics"],
|
||||
"econometrics": ["economics"],
|
||||
"electrical_engineering": ["engineering"],
|
||||
"elementary_mathematics": ["math"],
|
||||
"formal_logic": ["philosophy"],
|
||||
"global_facts": ["other"],
|
||||
"high_school_biology": ["biology"],
|
||||
"high_school_chemistry": ["chemistry"],
|
||||
"high_school_computer_science": ["computer science"],
|
||||
"high_school_european_history": ["history"],
|
||||
"high_school_geography": ["geography"],
|
||||
"high_school_government_and_politics": ["politics"],
|
||||
"high_school_macroeconomics": ["economics"],
|
||||
"high_school_mathematics": ["math"],
|
||||
"high_school_microeconomics": ["economics"],
|
||||
"high_school_physics": ["physics"],
|
||||
"high_school_psychology": ["psychology"],
|
||||
"high_school_statistics": ["math"],
|
||||
"high_school_us_history": ["history"],
|
||||
"high_school_world_history": ["history"],
|
||||
"human_aging": ["health"],
|
||||
"human_sexuality": ["culture"],
|
||||
"international_law": ["law"],
|
||||
"jurisprudence": ["law"],
|
||||
"logical_fallacies": ["philosophy"],
|
||||
"machine_learning": ["computer science"],
|
||||
"management": ["business"],
|
||||
"marketing": ["business"],
|
||||
"medical_genetics": ["health"],
|
||||
"miscellaneous": ["other"],
|
||||
"moral_disputes": ["philosophy"],
|
||||
"moral_scenarios": ["philosophy"],
|
||||
"nutrition": ["health"],
|
||||
"philosophy": ["philosophy"],
|
||||
"prehistory": ["history"],
|
||||
"professional_accounting": ["other"],
|
||||
"professional_law": ["law"],
|
||||
"professional_medicine": ["health"],
|
||||
"professional_psychology": ["psychology"],
|
||||
"public_relations": ["politics"],
|
||||
"security_studies": ["politics"],
|
||||
"sociology": ["culture"],
|
||||
"us_foreign_policy": ["politics"],
|
||||
"virology": ["health"],
|
||||
"world_religions": ["philosophy"],
|
||||
}
|
||||
CATEGORY_TO_SUBCATEGORIES = {
|
||||
"STEM": [
|
||||
"physics",
|
||||
"chemistry",
|
||||
"biology",
|
||||
"computer science",
|
||||
"math",
|
||||
"engineering",
|
||||
],
|
||||
"humanities": ["history", "philosophy", "law"],
|
||||
"social sciences": [
|
||||
"politics",
|
||||
"culture",
|
||||
"economics",
|
||||
"geography",
|
||||
"psychology",
|
||||
],
|
||||
"other (business, health, misc.)": ["other", "business", "health"],
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
dataset_path: str,
|
||||
num_samples: int = None,
|
||||
num_train: int = 5,
|
||||
random_seed: int = 0):
|
||||
self.dataset_path = dataset_path
|
||||
if num_samples is None:
|
||||
self.num_samples_per_subject = None
|
||||
else:
|
||||
self.num_samples_per_subject = math.ceil(
|
||||
num_samples / len(self.SUBJECT_TO_SUBCATEGORIES))
|
||||
self.num_train = num_train
|
||||
random.seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
|
||||
def format_subject(self, subject):
|
||||
line = subject.split("_")
|
||||
s = ""
|
||||
for entry in line:
|
||||
s += " " + entry
|
||||
return s
|
||||
|
||||
def format_example(self, df, idx, include_answer=True):
|
||||
prompt = df.iloc[idx, 0]
|
||||
k = df.shape[1] - 2
|
||||
for j in range(k):
|
||||
prompt += "\n{}. {}".format(self.CHOICES[j], df.iloc[idx, j + 1])
|
||||
prompt += "\nAnswer:"
|
||||
if include_answer:
|
||||
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
|
||||
return prompt
|
||||
|
||||
def gen_prompt(self, train_df, subject, k=-1):
|
||||
prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
|
||||
self.format_subject(subject))
|
||||
if k == -1:
|
||||
k = train_df.shape[0]
|
||||
for i in range(k):
|
||||
prompt += self.format_example(train_df, i)
|
||||
return prompt
|
||||
|
||||
def generate_samples(self) -> Iterable[tuple]:
|
||||
for subject in self.SUBJECT_TO_SUBCATEGORIES.keys():
|
||||
dev_df = pd.read_csv(f"{self.dataset_path}/dev/{subject}_dev.csv",
|
||||
header=None)
|
||||
train_prompt = self.gen_prompt(dev_df, subject, self.num_train)
|
||||
|
||||
test_df = pd.read_csv(
|
||||
f"{self.dataset_path}/test/{subject}_test.csv", header=None)
|
||||
if self.num_samples_per_subject is not None and self.num_samples_per_subject < test_df.shape[
|
||||
0]:
|
||||
test_df = test_df.sample(self.num_samples_per_subject)
|
||||
|
||||
for i in range(test_df.shape[0]):
|
||||
prompt_end = self.format_example(test_df,
|
||||
i,
|
||||
include_answer=False)
|
||||
prompt = train_prompt + prompt_end
|
||||
label = test_df.iloc[i, test_df.shape[1] - 1]
|
||||
yield prompt, label, subject
|
||||
|
||||
def compute_score(self, outputs: List[RequestOutput], references: List[str],
|
||||
subjects: List[str]) -> float:
|
||||
subject_corrections = {
|
||||
key: []
|
||||
for key in self.SUBJECT_TO_SUBCATEGORIES.keys()
|
||||
}
|
||||
for output, ref, sub in zip(outputs, references, subjects):
|
||||
correction = output.outputs[0].text.strip().startswith(ref)
|
||||
subject_corrections[sub].append(correction)
|
||||
|
||||
subcategory_corrections = {
|
||||
key: []
|
||||
for subcats in self.SUBJECT_TO_SUBCATEGORIES.values()
|
||||
for key in subcats
|
||||
}
|
||||
category_corrections = {
|
||||
key: []
|
||||
for key in self.CATEGORY_TO_SUBCATEGORIES.keys()
|
||||
}
|
||||
all_corrections = []
|
||||
for sub, corrections in subject_corrections.items():
|
||||
for subcat in self.SUBJECT_TO_SUBCATEGORIES[sub]:
|
||||
subcategory_corrections[subcat].extend(corrections)
|
||||
for cat, subcats in self.CATEGORY_TO_SUBCATEGORIES.items():
|
||||
if subcat in subcats:
|
||||
category_corrections[cat].extend(corrections)
|
||||
all_corrections.extend(corrections)
|
||||
|
||||
for subject, corrections in subject_corrections.items():
|
||||
acc = np.mean(corrections) * 100
|
||||
logger.info(
|
||||
f"Average accuracy {acc:.2f} ({len(corrections)}) - {subject}")
|
||||
|
||||
for subcat, corrections in subcategory_corrections.items():
|
||||
acc = np.mean(corrections) * 100
|
||||
logger.info(
|
||||
f"Average accuracy {acc:.2f} ({len(corrections)}) - {subcat}")
|
||||
|
||||
for cat, corrections in category_corrections.items():
|
||||
acc = np.mean(corrections) * 100
|
||||
logger.info(
|
||||
f"Average accuracy {acc:.2f} ({len(corrections)}) - {cat}")
|
||||
|
||||
weighted_acc = np.mean(all_corrections) * 100
|
||||
logger.info(
|
||||
f"MMLU weighted average accuracy: {weighted_acc:.2f} ({len(all_corrections)})"
|
||||
)
|
||||
return weighted_acc
|
||||
|
||||
@click.command("mmlu")
|
||||
@click.option("--dataset_path", type=str, required=True)
|
||||
@click.option("--num_samples", type=int, default=None)
|
||||
@click.option("--num_train", type=int, default=5)
|
||||
@click.option("--random_seed", type=int, default=0)
|
||||
@click.option("--max_input_length", type=int, default=4094)
|
||||
@click.option("--max_output_length", type=int, default=2)
|
||||
@click.option("--check_accuracy", is_flag=True, default=False)
|
||||
@click.option("--accuracy_threshold", type=float, default=30)
|
||||
@click.pass_context
|
||||
@staticmethod
|
||||
def command(ctx, dataset_path: str, num_samples: int, num_train: int,
|
||||
random_seed: int, max_input_length: int, max_output_length: int,
|
||||
check_accuracy: bool, accuracy_threshold: float) -> None:
|
||||
llm: Union[LLM, PyTorchLLM] = ctx.obj
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_output_length,
|
||||
truncate_prompt_tokens=max_input_length)
|
||||
evaluator = MMLU(dataset_path,
|
||||
num_samples=num_samples,
|
||||
num_train=num_train,
|
||||
random_seed=random_seed)
|
||||
accuracy = evaluator.evaluate(llm, sampling_params)
|
||||
llm.shutdown()
|
||||
|
||||
if check_accuracy:
|
||||
assert accuracy >= accuracy_threshold, f"Expected accuracy >= {accuracy_threshold}, but got {accuracy}"
|
||||
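The MMLU evaluator above expects the standard per-subject CSV layout, i.e. <dataset_path>/dev/<subject>_dev.csv for the few-shot examples and <dataset_path>/test/<subject>_test.csv for the scored questions. A hedged sketch of driving it directly through the LLM API (paths are placeholders; the 5-shot / 2-token settings mirror the subcommand defaults):

from tensorrt_llm.evaluate import MMLU
from tensorrt_llm.llmapi import LLM
from tensorrt_llm.sampling_params import SamplingParams

llm = LLM(model="/path/to/hf_checkpoint")  # placeholder checkpoint
sampling_params = SamplingParams(max_tokens=2, truncate_prompt_tokens=4094)
evaluator = MMLU(dataset_path="/path/to/mmlu_data", num_train=5)
weighted_acc = evaluator.evaluate(llm, sampling_params)  # weighted average accuracy in %
llm.shutdown()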
@ -3,12 +3,11 @@ from ..executor import CompletionOutput, RequestError
|
||||
from ..sampling_params import GuidedDecodingParams, SamplingParams
|
||||
from .build_cache import BuildCacheConfig
|
||||
from .llm import LLM, RequestOutput
|
||||
from .llm_args import (EagleDecodingConfig, MedusaDecodingConfig,
|
||||
MTPDecodingConfig)
|
||||
from .llm_args import (EagleDecodingConfig, LookaheadDecodingConfig,
|
||||
MedusaDecodingConfig, MTPDecodingConfig)
|
||||
from .llm_utils import (BuildConfig, CalibConfig, CapacitySchedulerPolicy,
|
||||
KvCacheConfig, KvCacheRetentionConfig,
|
||||
LookaheadDecodingConfig, QuantAlgo, QuantConfig,
|
||||
SchedulerConfig)
|
||||
KvCacheConfig, KvCacheRetentionConfig, QuantAlgo,
|
||||
QuantConfig, SchedulerConfig)
|
||||
from .mpi_session import MpiCommSession
|
||||
|
||||
__all__ = [
|
||||
|
||||
@ -928,7 +928,7 @@ class LlmArgs:
|
||||
self.moe_expert_parallel_size = -1
|
||||
|
||||
if self.cp_config is None:
|
||||
self.co_config = {}
|
||||
self.cp_config = {}
|
||||
|
||||
self.parallel_config = _ParallelConfig(
|
||||
tp_size=self.tensor_parallel_size,
|
||||
|
||||
@ -343,6 +343,90 @@ class ModelLoader:
|
||||
|
||||
assert self.speculative_model_obj.is_local_model
|
||||
|
||||
def _update_from_hf_quant_config(self) -> bool:
|
||||
"""Update quant_config from the config file of pre-quantized HF checkpoint.
|
||||
|
||||
Returns:
|
||||
prequantized (bool): Whether the checkpoint is pre-quantized.
|
||||
"""
|
||||
quant_config = self.llm_args.quant_config
|
||||
|
||||
hf_quant_config_path = f"{self._model_dir}/hf_quant_config.json"
|
||||
if os.path.exists(hf_quant_config_path):
|
||||
logger.info(
|
||||
f"Found {hf_quant_config_path}, pre-quantized checkpoint is used."
|
||||
)
|
||||
with open(hf_quant_config_path, "r") as f:
|
||||
hf_quant_config = json.load(f)
|
||||
hf_quant_config = hf_quant_config["quantization"]
|
||||
|
||||
hf_quant_algo = hf_quant_config.pop("quant_algo", None)
|
||||
if hf_quant_algo is not None:
|
||||
hf_quant_algo = QuantAlgo(hf_quant_algo)
|
||||
if quant_config.quant_algo is None:
|
||||
logger.info(
|
||||
f"Setting quant_algo={hf_quant_algo} form HF quant config."
|
||||
)
|
||||
quant_config.quant_algo = hf_quant_algo
|
||||
elif quant_config.quant_algo != hf_quant_algo:
|
||||
raise ValueError(
|
||||
f"Specified quant_algo={quant_config.quant_algo}, conflicting with quant_algo={hf_quant_algo} from HF quant config."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Pre-quantized checkpoint must have quant_algo.")
|
||||
|
||||
hf_kv_cache_quant_algo = hf_quant_config.pop(
|
||||
"kv_cache_quant_algo", None)
|
||||
if hf_kv_cache_quant_algo is not None:
|
||||
hf_kv_cache_quant_algo = QuantAlgo(hf_kv_cache_quant_algo)
|
||||
if quant_config.kv_cache_quant_algo is None:
|
||||
logger.info(
|
||||
f"Setting kv_cache_quant_algo={hf_kv_cache_quant_algo} form HF quant config."
|
||||
)
|
||||
quant_config.kv_cache_quant_algo = hf_kv_cache_quant_algo
|
||||
elif quant_config.kv_cache_quant_algo != hf_kv_cache_quant_algo:
|
||||
raise ValueError(
|
||||
f"Specified kv_cache_quant_algo={quant_config.kv_cache_quant_algo}, conflicting with kv_cache_quant_algo={hf_kv_cache_quant_algo} from HF quant config."
|
||||
)
|
||||
else:
|
||||
if quant_config.kv_cache_quant_algo not in [
|
||||
None, QuantAlgo.FP8
|
||||
]:
|
||||
raise ValueError(
|
||||
f"Only kv_cache_quant_algo={QuantAlgo.FP8} is allowed for pre-quantized checkpoint, got {quant_config.kv_cache_quant_algo}."
|
||||
)
|
||||
|
||||
for key, value in hf_quant_config.items():
|
||||
logger.info(f"Setting {key}={value} from HF quant config.")
|
||||
setattr(quant_config, key, value)
|
||||
|
||||
return True
|
||||
|
||||
hf_config_path = f"{self._model_dir}/config.json"
|
||||
if os.path.exists(hf_config_path):
|
||||
with open(hf_config_path, "r") as f:
|
||||
hf_config = json.load(f)
|
||||
hf_quant_config = hf_config.get("quantization_config", None)
|
||||
|
||||
if hf_quant_config is not None:
|
||||
logger.info(
|
||||
f"Found quantization_config field in {hf_config_path}, pre-quantized checkpoint is used."
|
||||
)
|
||||
# DeepSeek V3 FP8 ckpt
|
||||
if hf_quant_config.get(
|
||||
"quant_method") == "fp8" and hf_quant_config.get(
|
||||
"weight_block_size"):
|
||||
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
|
||||
quant_config.exclude_modules = ["*eh_proj"]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported quantization_config: {hf_quant_config}.")
|
||||
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
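For reference, the keys consumed by _update_from_hf_quant_config come from ModelOpt-style hf_quant_config.json files; an illustrative (not exhaustive) shape of the parsed file, with example values rather than a specification:

# Illustrative parsed hf_quant_config.json; only keys read by the method are shown.
hf_quant_config = {
    "quantization": {
        "quant_algo": "FP8",           # converted to QuantAlgo("FP8")
        "kv_cache_quant_algo": "FP8",  # optional; treated as None when absent
        # any remaining keys are copied onto quant_config via setattr()
    }
}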
def _load_model_from_hf(self):
|
||||
''' Load a TRT-LLM model from a HF model. '''
|
||||
assert self._model_dir is not None
|
||||
@ -353,46 +437,7 @@ class ModelLoader:
|
||||
if hasattr(self.llm_args, "speculative_model")
|
||||
and self.llm_args.speculative_model else None)
|
||||
|
||||
# Update quant_config if it's ModelOpt quantized ckpt
|
||||
user_quant_config = self.llm_args.quant_config
|
||||
hf_quant_config_path = Path(self._model_dir) / "hf_quant_config.json"
|
||||
if hf_quant_config_path.exists():
|
||||
logger.info(
|
||||
f"Found {hf_quant_config_path}, pre-quantized checkpoints are used."
|
||||
)
|
||||
already_quantized = True
|
||||
with open(hf_quant_config_path, "r") as f:
|
||||
hf_quant_config = json.load(f)
|
||||
hf_quant_algo = hf_quant_config["quantization"].get(
|
||||
"quant_algo")
|
||||
if hf_quant_algo == "FP8" and user_quant_config.quant_algo \
|
||||
and user_quant_config.quant_algo != QuantAlgo.FP8:
|
||||
raise ValueError(
|
||||
f"Expecting quant_algo to be FP8, got {user_quant_config.quant_algo}."
|
||||
)
|
||||
user_quant_config.quant_algo = hf_quant_algo
|
||||
logger.info(f"quant_algo is set to {hf_quant_algo}")
|
||||
|
||||
hf_kv_cache_quant_algo = hf_quant_config["quantization"].get(
|
||||
"kv_cache_quant_algo")
|
||||
if hf_kv_cache_quant_algo != user_quant_config.kv_cache_quant_algo:
|
||||
if user_quant_config.kv_cache_quant_algo is None:
|
||||
user_quant_config.kv_cache_quant_algo = hf_kv_cache_quant_algo
|
||||
logger.info(
|
||||
f"kv_cache_quant_algo is set to {hf_kv_cache_quant_algo}"
|
||||
)
|
||||
elif user_quant_config.kv_cache_quant_algo == QuantAlgo.FP8 and hf_kv_cache_quant_algo is None:
|
||||
logger.warning(
|
||||
f"User specified kv_cache_quant_algo {user_quant_config.kv_cache_quant_algo} "
|
||||
f"will overwrite {hf_kv_cache_quant_algo} from {hf_quant_config_path}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"User specified kv_cache_quant_algo {user_quant_config.kv_cache_quant_algo}, "
|
||||
f"while it's {hf_kv_cache_quant_algo} in {hf_quant_config_path}."
|
||||
)
|
||||
else:
|
||||
already_quantized = False
|
||||
prequantized = self._update_from_hf_quant_config()
|
||||
|
||||
# FP4 Gemm force to use plugin.
|
||||
if self.llm_args.quant_config.quant_mode.has_nvfp4():
|
||||
@ -407,7 +452,7 @@ class ModelLoader:
|
||||
**self.convert_checkpoint_options,
|
||||
)
|
||||
self.model = model_cls(config)
|
||||
elif self.llm_args.quant_config._requires_calibration and not already_quantized:
|
||||
elif self.llm_args.quant_config._requires_calibration and not prequantized:
|
||||
assert self.workspace is not None
|
||||
checkpoint_dir = f"{self.workspace}/quantized-checkpoint"
|
||||
if self.rank == 0:
|
||||
@ -598,7 +643,7 @@ class CachedModelLoader:
|
||||
self.llm_build_stats.engine_dir = self.model_loader.model_obj.model_dir
|
||||
return self.llm_build_stats.engine_dir, self._hf_model_dir
|
||||
|
||||
if (self.llm_args.backend is not None):
|
||||
if self.llm_args.backend is not None:
|
||||
if self.llm_args.backend not in ["pytorch", "autodeploy"]:
|
||||
raise ValueError(
|
||||
f'backend {self.llm_args.backend} is not supported.')
|
||||
@ -615,6 +660,10 @@ class CachedModelLoader:
|
||||
logger.warning(
|
||||
"QuantConfig for pytorch backend is ignored. You can load"
|
||||
"quantized model with hf_quant_config.json directly.")
|
||||
# Currently, this is to make updated quant_config visible by llm.args.quant_config
|
||||
# TODO: Unify the logics with those in tensorrt_llm/_torch/model_config.py
|
||||
self.model_loader._update_from_hf_quant_config()
|
||||
|
||||
return None, self._hf_model_dir
|
||||
|
||||
return self._build_model(), self._hf_model_dir
|
||||
|
||||
@ -1190,10 +1190,11 @@ def postprocess_fp8_rowwise(tllm_key, weights, **kwargs):
|
||||
tllm_key.replace("weight", "per_channel_scale"): scales
|
||||
}
|
||||
else:
|
||||
x = torch.cat(weights, dim=0).to(torch.float32)
|
||||
clamp_val = config.quantization.clamp_val
|
||||
weights = torch.cat(weights, dim=0)
|
||||
# activation range bound.
|
||||
x = weights.to(torch.float32).clamp(clamp_val[0], clamp_val[1])
|
||||
if clamp_val is not None:
|
||||
# activation range bound.
|
||||
x = x.clamp(clamp_val[0], clamp_val[1])
|
||||
xmax = x.abs().max(-1, keepdim=True).values
|
||||
# minimum scaling factor.
|
||||
torch_weight_scales = (xmax / 448.0).clamp(min=1.0 / (448.0 * 512.0))
|
||||
|
||||
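The hunk above guards the activation-range clamp behind a `clamp_val is not None` check before the per-channel FP8 scales are computed. A standalone numeric sketch of that scale computation (tensor shape and clamp bounds chosen arbitrarily):

import torch

weights = torch.randn(4, 8)   # arbitrary [out_channels, in_channels] shard
clamp_val = [-10.0, 10.0]     # example bound; may be None for some checkpoints

x = weights.to(torch.float32)
if clamp_val is not None:
    x = x.clamp(clamp_val[0], clamp_val[1])
xmax = x.abs().max(-1, keepdim=True).values                # per-row max magnitude
scales = (xmax / 448.0).clamp(min=1.0 / (448.0 * 512.0))   # 448 = FP8 E4M3 max
# scales has shape [4, 1]: one scaling factor per output channel.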
@ -55,7 +55,7 @@ pip install -r requirements-dev.txt
cd tests/integration/defs

# example 1: run a case
pytest "accuracy/test_accuracy.py::TestGpt2CnnDailymail::test_auto_dtype"
pytest "accuracy/test_cli_flow.py::TestGpt2CnnDailymail::test_auto_dtype"

# example 2: run a test list
pytest --rootdir . --test-list=<a txt file containing one test case per line>
@ -18,7 +18,7 @@ Since we care about accuracy regression only, so it should be a one-tailed hypot
* Null Hypothesis ($H_0$): $x'_1, x'_2, \dots, x'_n$ are drawn from a distribution with a mean equal to or higher than the reference.
* Alternative Hypothesis ($H_1$): $x'_1, x'_2, \dots, x'_n$ are drawn from a distribution with a mean lower than the reference.




### Two-sample t-test
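To make the one-tailed test concrete: given a reference mean, an assumed per-sample standard deviation, a sample count, and a significance level, the rejection threshold for the sample mean can be derived as below. This is a sketch of the general formulation only; the constants and exact statistics live in accuracy_core.py and may differ in detail.

import scipy.stats

def one_tailed_threshold(reference_mean: float, sigma: float,
                         num_samples: int, alpha: float) -> float:
    # Reject H0 (no regression) when the observed mean falls below this value.
    # Assumes the sample mean is approximately normal with std sigma/sqrt(n);
    # a t-distribution could be substituted for small n.
    z = scipy.stats.norm.ppf(alpha)  # negative for small alpha
    return reference_mean + z * sigma / (num_samples ** 0.5)

# Example with the MMLU task constants appearing in this commit (alpha=0.01, sigma=50, n=4096):
print(one_tailed_threshold(66.06, sigma=50, num_samples=4096, alpha=0.01))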
@ -15,13 +15,17 @@
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import pytest
|
||||
import scipy
|
||||
import yaml
|
||||
|
||||
import tensorrt_llm.evaluate
|
||||
from tensorrt_llm._torch import LLM as PyTorchLLM
|
||||
from tensorrt_llm.builder import BuildConfig
|
||||
from tensorrt_llm.llmapi import LLM, SamplingParams
|
||||
from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
from tensorrt_llm.quantization import QuantAlgo
|
||||
|
||||
@ -55,8 +59,7 @@ class AccuracyTask:
|
||||
|
||||
# Dataset
|
||||
DATASET = None
|
||||
DATASET_DIR = f"{llm_models_root()}/datasets"
|
||||
ROUGE_DIR = f"{llm_models_root()}/rouge"
|
||||
DATASET_DIR = None
|
||||
HIGHER_IS_BETTER = True
|
||||
|
||||
# Hypothesis testing parameters
|
||||
@ -125,9 +128,38 @@ class AccuracyTask:
|
||||
"===========================================================\n")
|
||||
return num_samples, threshold
|
||||
|
||||
def create_evaluator(self, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def evaluate(self,
|
||||
llm: Union[LLM, PyTorchLLM],
|
||||
extra_acc_spec: Optional[str] = None):
|
||||
spec_dec_algo = None
|
||||
if llm.args.speculative_config is not None:
|
||||
spec_dec_algo = llm.args.speculative_config.decoding_type
|
||||
|
||||
num_samples, threshold = self.get_num_samples_and_threshold(
|
||||
dtype=llm.args.dtype,
|
||||
quant_algo=llm.args.quant_config.quant_algo,
|
||||
kv_cache_quant_algo=llm.args.quant_config.kv_cache_quant_algo,
|
||||
spec_dec_algo=spec_dec_algo,
|
||||
extra_acc_spec=extra_acc_spec)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=self.MAX_OUTPUT_LEN,
|
||||
truncate_prompt_tokens=self.MAX_INPUT_LEN)
|
||||
evaluator = self.create_evaluator(num_samples=num_samples)
|
||||
accuracy = evaluator.evaluate(llm, sampling_params)
|
||||
if self.HIGHER_IS_BETTER:
|
||||
assert accuracy >= threshold, f"Expected accuracy >= {threshold}, but got {accuracy}"
|
||||
else:
|
||||
assert accuracy <= threshold, f"Expected accuracy <= {threshold}, but got {accuracy}"
|
||||
|
||||
|
||||
class CnnDailymail(AccuracyTask):
|
||||
DATASET = "cnn_dailymail"
|
||||
DATASET_DIR = f"{llm_models_root()}/datasets/ccdv/cnn_dailymail"
|
||||
ROUGE_DIR = f"{llm_models_root()}/rouge"
|
||||
|
||||
ALPHA = 0.002
|
||||
BETA = 0.2
|
||||
@ -138,9 +170,17 @@ class CnnDailymail(AccuracyTask):
|
||||
MAX_INPUT_LEN = 924
|
||||
MAX_OUTPUT_LEN = 100
|
||||
|
||||
def create_evaluator(self, **kwargs):
|
||||
return tensorrt_llm.evaluate.CnnDailymail(dataset_path=self.DATASET_DIR,
|
||||
random_seed=0,
|
||||
rouge_path=self.ROUGE_DIR,
|
||||
**kwargs)
|
||||
|
||||
|
||||
class Humaneval(AccuracyTask):
|
||||
DATASET = "humaneval"
|
||||
DATASET_DIR = f"{llm_models_root()}/datasets/openai_humaneval"
|
||||
ROUGE_DIR = f"{llm_models_root()}/rouge"
|
||||
|
||||
ALPHA = 0.002
|
||||
BETA = 0.2
|
||||
@ -154,6 +194,8 @@ class Humaneval(AccuracyTask):
|
||||
|
||||
class ZeroScrolls(AccuracyTask):
|
||||
DATASET = "zero_scrolls"
|
||||
DATASET_DIR = f"{llm_models_root()}/datasets/tau/zero_scrolls"
|
||||
ROUGE_DIR = f"{llm_models_root()}/rouge"
|
||||
|
||||
ALPHA = 0.002
|
||||
BETA = 0.2
|
||||
@ -167,9 +209,11 @@ class ZeroScrolls(AccuracyTask):
|
||||
|
||||
class SlimPajama6B(AccuracyTask):
|
||||
DATASET = "SlimPajama-6B"
|
||||
DATASET_DIR = f"{llm_models_root()}/datasets/SlimPajama-6B"
|
||||
HIGHER_IS_BETTER = False
|
||||
ROUGE_DIR = f"{llm_models_root()}/rouge"
|
||||
|
||||
ALPHA = 0.002
|
||||
ALPHA = 0.01
|
||||
BETA = 0.2
|
||||
SIGMA = 4.48
|
||||
NUM_SAMPLES = 86 # Full sample with length >= 10000
|
||||
@ -180,11 +224,11 @@ class SlimPajama6B(AccuracyTask):
|
||||
MAX_OUTPUT_LEN = 1
|
||||
|
||||
|
||||
class Mmlu(AccuracyTask):
|
||||
class MMLU(AccuracyTask):
|
||||
DATASET = "mmlu"
|
||||
DATASET_DIR = f"{llm_models_root()}/datasets/mmlu"
|
||||
|
||||
ALPHA = 0.002
|
||||
ALPHA = 0.01
|
||||
BETA = 0.2
|
||||
SIGMA = 50
|
||||
NUM_SAMPLES = 4096
|
||||
@ -193,6 +237,11 @@ class Mmlu(AccuracyTask):
|
||||
MAX_INPUT_LEN = 4094
|
||||
MAX_OUTPUT_LEN = 2
|
||||
|
||||
def create_evaluator(self, **kwargs):
|
||||
return tensorrt_llm.evaluate.MMLU(dataset_path=self.DATASET_DIR,
|
||||
random_seed=0,
|
||||
**kwargs)
|
||||
|
||||
|
||||
class PassKeyRetrieval64k(AccuracyTask):
|
||||
DATASET = "passkey_retrieval_64k"
|
||||
@ -224,7 +273,7 @@ class PassKeyRetrieval128k(AccuracyTask):
|
||||
MAX_OUTPUT_LEN = 50
|
||||
|
||||
|
||||
class AccuracyTestHarness:
|
||||
class CliFlowAccuracyTestHarness:
|
||||
# Model
|
||||
MODEL_NAME = None
|
||||
MODEL_PATH = None
|
||||
@ -295,22 +344,22 @@ class AccuracyTestHarness:
|
||||
def convert(self):
|
||||
print("Converting model to TensorRT-LLM checkpoint...")
|
||||
|
||||
is_pre_quantized = False
|
||||
is_prequantized = False
|
||||
for quant_config_file in [
|
||||
"hf_quant_config.json", "quant_config.json",
|
||||
"quantize_config.json"
|
||||
]:
|
||||
if exists(f"{self.MODEL_PATH}/{quant_config_file}"):
|
||||
is_pre_quantized = True
|
||||
is_prequantized = True
|
||||
break
|
||||
if not is_pre_quantized and exists(f"{self.MODEL_PATH}/config.json"):
|
||||
if not is_prequantized and exists(f"{self.MODEL_PATH}/config.json"):
|
||||
with open(f"{self.MODEL_PATH}/config.json") as f:
|
||||
hf_config = json.load(f)
|
||||
if "quantization_config" in hf_config:
|
||||
is_pre_quantized = True
|
||||
is_prequantized = True
|
||||
|
||||
quant_config = QuantConfig(self.quant_algo, self.kv_cache_quant_algo)
|
||||
if not is_pre_quantized and quant_config._requires_modelopt_quantization:
|
||||
if not is_prequantized and quant_config._requires_modelopt_quantization:
|
||||
script = "../quantization/quantize.py"
|
||||
else:
|
||||
script = "convert_checkpoint.py"
|
||||
@ -333,7 +382,7 @@ class AccuracyTestHarness:
|
||||
if self.cp_size > 1:
|
||||
convert_cmd.append(f"--cp_size={self.cp_size}")
|
||||
|
||||
if not is_pre_quantized and quant_config._requires_modelopt_quantization:
|
||||
if not is_prequantized and quant_config._requires_modelopt_quantization:
|
||||
if self.quant_algo == QuantAlgo.MIXED_PRECISION:
|
||||
assert self.extra_convert_args is not None
|
||||
assert any(
|
||||
@ -548,7 +597,7 @@ class AccuracyTestHarness:
|
||||
if isinstance(task,
|
||||
(CnnDailymail, Humaneval, ZeroScrolls, SlimPajama6B)):
|
||||
self.summarize(task)
|
||||
elif isinstance(task, Mmlu):
|
||||
elif isinstance(task, MMLU):
|
||||
self.mmlu(task)
|
||||
elif isinstance(task, (PassKeyRetrieval64k, PassKeyRetrieval128k)):
|
||||
self.eval_long_context(task)
|
||||
@ -589,3 +638,13 @@ class AccuracyTestHarness:
|
||||
self.convert()
|
||||
self.build()
|
||||
self.evaluate()
|
||||
|
||||
|
||||
class LlmapiAccuracyTestHarness:
|
||||
# Model
|
||||
MODEL_NAME = None
|
||||
MODEL_PATH = None
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
logger.set_level("info")
|
||||
|
||||
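The new LlmapiAccuracyTestHarness above skips the CLI convert/build/summarize steps; tests construct the LLM directly and reuse the task thresholds. A hypothetical test class using it (model name and path are placeholders, not part of this commit):

from tensorrt_llm.llmapi import LLM

from ..conftest import llm_models_root
from .accuracy_core import CnnDailymail, LlmapiAccuracyTestHarness


class TestMyModel(LlmapiAccuracyTestHarness):
    MODEL_NAME = "my-org/my-model"                 # placeholder
    MODEL_PATH = f"{llm_models_root()}/my-model"   # placeholder

    def test_auto_dtype(self):
        llm = LLM(model=self.MODEL_PATH)
        try:
            # AccuracyTask.evaluate() asserts against the reference threshold.
            CnnDailymail(self.MODEL_NAME).evaluate(llm)
        finally:
            llm.shutdown()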
@ -36,12 +36,12 @@ microsoft/Phi-3.5-mini-instruct:
|
||||
state-spaces/mamba-130m-hf:
|
||||
- accuracy: 19.470
|
||||
lmsys/vicuna-7b-v1.3:
|
||||
- spec_dec_algo: lookahead
|
||||
- spec_dec_algo: Lookahead
|
||||
accuracy: 33.427
|
||||
- dtype: float16
|
||||
spec_dec_algo: medusa
|
||||
spec_dec_algo: Medusa
|
||||
accuracy: 33.419
|
||||
- spec_dec_algo: eagle
|
||||
- spec_dec_algo: Eagle
|
||||
accuracy: 27.832
|
||||
llama-7b-hf:
|
||||
- accuracy: 30.457
|
||||
@ -94,6 +94,9 @@ meta-llama/Llama-3.1-8B:
|
||||
- accuracy: 24.360
|
||||
- quant_algo: W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
|
||||
accuracy: 25.004
|
||||
- quant_algo: NVFP4
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 25.469
|
||||
- quant_algo: FP8
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 24.359
|
||||
@ -111,7 +114,7 @@ meta-llama/Llama-3.1-8B-Instruct:
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 33.464
|
||||
- dtype: float16
|
||||
spec_dec_algo: medusa
|
||||
spec_dec_algo: Medusa
|
||||
accuracy: 33.663
|
||||
meta-llama/Llama-3.2-1B:
|
||||
- accuracy: 27.427
|
||||
@ -136,8 +139,17 @@ meta-llama/Llama-3.2-1B:
|
||||
accuracy: 27.259
|
||||
- extra_acc_spec: max_attention_window_size=960;beam_width=4
|
||||
accuracy: 0
|
||||
meta-llama/Llama-3.3-70B-Instruct:
|
||||
- quant_algo: NVFP4
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 34.383
|
||||
- quant_algo: FP8
|
||||
accuracy: 34.927
|
||||
mistralai/Mixtral-8x7B-v0.1:
|
||||
- accuracy: 28.810
|
||||
- quant_algo: NVFP4
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 29.733
|
||||
- quant_algo: FP8
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 27.109
|
||||
@ -206,3 +218,9 @@ nvidia/Nemotron-Mini-4B-Instruct:
|
||||
- quant_algo: FP8
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 25.72
|
||||
deepseek-ai/DeepSeek-V3-Lite:
|
||||
- accuracy: 25.682
|
||||
- quant_algo: NVFP4
|
||||
accuracy: 25.243
|
||||
- quant_algo: FP8_BLOCK_SCALES
|
||||
accuracy: 25.546
|
||||
|
||||
@ -10,11 +10,20 @@ meta-llama/Meta-Llama-3-8B-Instruct:
|
||||
accuracy: 63.47
|
||||
meta-llama/Llama-3.1-8B:
|
||||
- accuracy: 66.06
|
||||
- quant_algo: NVFP4
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 63.16
|
||||
- quant_algo: FP8_PER_CHANNEL_PER_TOKEN
|
||||
accuracy: 65.55
|
||||
- quant_algo: MIXED_PRECISION
|
||||
extra_acc_spec: autoq_format=int4_awq,fp8,w4a8_awq;auto_quantize_bits=5.8
|
||||
accuracy: 64.99
|
||||
meta-llama/Llama-3.3-70B-Instruct:
|
||||
- quant_algo: NVFP4
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 79.31
|
||||
- quant_algo: FP8
|
||||
accuracy: 81.02
|
||||
mistralai/Mixtral-8x7B-v0.1:
|
||||
- accuracy: 71.35
|
||||
- quant_algo: FP8
|
||||
@ -33,3 +42,9 @@ Qwen/Qwen2.5-1.5B-Instruct:
|
||||
- accuracy: 61.45
|
||||
- quant_algo: FP8
|
||||
accuracy: 61.43
|
||||
deepseek-ai/DeepSeek-V3-Lite:
|
||||
- accuracy: 71.47
|
||||
- quant_algo: NVFP4
|
||||
accuracy: 70.66
|
||||
- quant_algo: FP8_BLOCK_SCALES
|
||||
accuracy: 71.32
|
||||
|
||||
@ -18,7 +18,7 @@ import re
import pandas as pd

metric_regex = {
    "rouge1": r"(?<=rouge1 : )\d+\.\d+",
    "rouge1": r"(?<=rouge1: )\d+\.\d+",
    "perplexity": r"(?<=Per-token perplexity: )\d+\.\d+",
    "mmlu": r"(?<=MMLU weighted average accuracy: )\d+\.\d+",
    "passkey": r"(?<=passkey accuracy: )\d+\.\d+"
@ -14,16 +14,18 @@
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from tensorrt_llm.llmapi import (EagleDecodingConfig, LookaheadDecodingConfig,
|
||||
MedusaDecodingConfig)
|
||||
from tensorrt_llm.quantization import QuantAlgo
|
||||
|
||||
from ..conftest import (llm_models_root, skip_no_nvls, skip_pre_ada,
|
||||
skip_pre_blackwell, skip_pre_hopper)
|
||||
from .accuracy_core import (AccuracyTestHarness, CnnDailymail, Humaneval, Mmlu,
|
||||
PassKeyRetrieval64k, PassKeyRetrieval128k,
|
||||
SlimPajama6B, ZeroScrolls)
|
||||
from .accuracy_core import (MMLU, CliFlowAccuracyTestHarness, CnnDailymail,
|
||||
Humaneval, PassKeyRetrieval64k,
|
||||
PassKeyRetrieval128k, SlimPajama6B, ZeroScrolls)
|
||||
|
||||
|
||||
class TestGpt2(AccuracyTestHarness):
|
||||
class TestGpt2(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "gpt2"
|
||||
MODEL_PATH = f"{llm_models_root()}/gpt2"
|
||||
EXAMPLE_FOLDER = "gpt"
|
||||
@ -99,7 +101,7 @@ class TestGpt2(AccuracyTestHarness):
|
||||
self.run(extra_summarize_args=["--cuda_graph_mode"])
|
||||
|
||||
|
||||
class TestGpt2Medium(AccuracyTestHarness):
|
||||
class TestGpt2Medium(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "gpt2-medium"
|
||||
MODEL_PATH = f"{llm_models_root()}/gpt2-medium"
|
||||
EXAMPLE_FOLDER = "gpt"
|
||||
@ -117,7 +119,7 @@ class TestGpt2Medium(AccuracyTestHarness):
|
||||
extra_convert_args=["--quantize_lm_head"])
|
||||
|
||||
|
||||
class TestSantacoder(AccuracyTestHarness):
|
||||
class TestSantacoder(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "bigcode/santacoder"
|
||||
MODEL_PATH = f"{llm_models_root()}/santacoder"
|
||||
EXAMPLE_FOLDER = "gpt"
|
||||
@ -127,7 +129,7 @@ class TestSantacoder(AccuracyTestHarness):
|
||||
self.run(tasks=[Humaneval(self.MODEL_NAME)], dtype='auto')
|
||||
|
||||
|
||||
class TestStarcoder2_3B(AccuracyTestHarness):
|
||||
class TestStarcoder2_3B(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "bigcode/starcoder2-3b"
|
||||
MODEL_PATH = f"{llm_models_root()}/starcoder2-3b"
|
||||
EXAMPLE_FOLDER = "gpt"
|
||||
@ -136,7 +138,7 @@ class TestStarcoder2_3B(AccuracyTestHarness):
|
||||
self.run(tasks=[Humaneval(self.MODEL_NAME)], dtype='auto')
|
||||
|
||||
|
||||
class TestStarcoder2_15B(AccuracyTestHarness):
|
||||
class TestStarcoder2_15B(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "bigcode/starcoder2-15b"
|
||||
MODEL_PATH = f"{llm_models_root()}/starcoder2-model"
|
||||
EXAMPLE_FOLDER = "gpt"
|
||||
@ -146,7 +148,7 @@ class TestStarcoder2_15B(AccuracyTestHarness):
|
||||
quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL)
|
||||
|
||||
|
||||
class TestGptNext(AccuracyTestHarness):
|
||||
class TestGptNext(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "gpt-next"
|
||||
MODEL_PATH = f"{llm_models_root()}/gpt-next/megatron_converted_843m_tp1_pp1.nemo"
|
||||
MODEL_FORMAT = "NEMO"
|
||||
@ -157,7 +159,7 @@ class TestGptNext(AccuracyTestHarness):
|
||||
self.run(dtype='auto')
|
||||
|
||||
|
||||
class TestMinitron4BBase(AccuracyTestHarness):
|
||||
class TestMinitron4BBase(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "nvidia/Minitron-4B-Base"
|
||||
MODEL_PATH = f"{llm_models_root()}/nemotron/Minitron-4B-Base"
|
||||
EXAMPLE_FOLDER = "gpt"
|
||||
@ -174,13 +176,13 @@ class TestMinitron4BBase(AccuracyTestHarness):
|
||||
kv_cache_quant_algo=QuantAlgo.FP8)
|
||||
|
||||
|
||||
class TestNemotronMini4BInstruct(AccuracyTestHarness):
|
||||
class TestNemotronMini4BInstruct(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "nvidia/Nemotron-Mini-4B-Instruct"
|
||||
MODEL_PATH = f"{llm_models_root()}/nemotron/Nemotron-Mini-4B-Instruct"
|
||||
EXAMPLE_FOLDER = "gpt"
|
||||
|
||||
@skip_pre_ada
|
||||
def test_fp8_pre_quantized(self, mocker):
|
||||
def test_fp8_prequantized(self, mocker):
|
||||
mocker.patch.object(
|
||||
self.__class__, "MODEL_PATH",
|
||||
f"{llm_models_root()}/nemotron/nemotron-mini-4b-instruct_vfp8-fp8-bf16-export"
|
||||
@ -188,7 +190,7 @@ class TestNemotronMini4BInstruct(AccuracyTestHarness):
|
||||
self.run(quant_algo=QuantAlgo.FP8, kv_cache_quant_algo=QuantAlgo.FP8)
|
||||
|
||||
|
||||
class TestPhi2(AccuracyTestHarness):
|
||||
class TestPhi2(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "microsoft/phi-2"
|
||||
MODEL_PATH = f"{llm_models_root()}/phi-2"
|
||||
EXAMPLE_FOLDER = "phi"
|
||||
@ -201,7 +203,7 @@ class TestPhi2(AccuracyTestHarness):
|
||||
self.run(tp_size=2)
|
||||
|
||||
|
||||
class TestPhi3Mini4kInstruct(AccuracyTestHarness):
|
||||
class TestPhi3Mini4kInstruct(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
|
||||
MODEL_PATH = f"{llm_models_root()}/Phi-3/Phi-3-mini-4k-instruct"
|
||||
EXAMPLE_FOLDER = "phi"
|
||||
@ -210,7 +212,7 @@ class TestPhi3Mini4kInstruct(AccuracyTestHarness):
|
||||
self.run(dtype='auto')
|
||||
|
||||
|
||||
class TestPhi3Mini128kInstruct(AccuracyTestHarness):
|
||||
class TestPhi3Mini128kInstruct(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "microsoft/Phi-3-mini-128k-instruct"
|
||||
MODEL_PATH = f"{llm_models_root()}/Phi-3/Phi-3-mini-128k-instruct"
|
||||
EXAMPLE_FOLDER = "phi"
|
||||
@ -219,7 +221,7 @@ class TestPhi3Mini128kInstruct(AccuracyTestHarness):
|
||||
self.run(dtype='auto')
|
||||
|
||||
|
||||
class TestPhi3Small8kInstruct(AccuracyTestHarness):
|
||||
class TestPhi3Small8kInstruct(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "microsoft/Phi-3-small-8k-instruct"
|
||||
MODEL_PATH = f"{llm_models_root()}/Phi-3/Phi-3-small-8k-instruct"
|
||||
EXAMPLE_FOLDER = "phi"
|
||||
@ -228,7 +230,7 @@ class TestPhi3Small8kInstruct(AccuracyTestHarness):
|
||||
self.run(dtype='auto')
|
||||
|
||||
|
||||
class TestPhi3Small128kInstruct(AccuracyTestHarness):
|
||||
class TestPhi3Small128kInstruct(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "microsoft/Phi-3-small-128k-instruct"
|
||||
MODEL_PATH = f"{llm_models_root()}/Phi-3/Phi-3-small-128k-instruct"
|
||||
EXAMPLE_FOLDER = "phi"
|
||||
@ -237,7 +239,7 @@ class TestPhi3Small128kInstruct(AccuracyTestHarness):
|
||||
self.run(dtype='auto')
|
||||
|
||||
|
||||
class TestPhi3_5MiniInstruct(AccuracyTestHarness):
|
||||
class TestPhi3_5MiniInstruct(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "microsoft/Phi-3.5-mini-instruct"
|
||||
MODEL_PATH = f"{llm_models_root()}/Phi-3.5/Phi-3.5-mini-instruct"
|
||||
EXAMPLE_FOLDER = "phi"
|
||||
@ -249,7 +251,7 @@ class TestPhi3_5MiniInstruct(AccuracyTestHarness):
|
||||
# Long sequence length test:
|
||||
# Model FP16 7B + 32K tokens in KV cache = 14 * 1024 MB + 32K * 0.5 MB = 30720 MB + scratch memory
|
||||
@pytest.mark.skip_less_device_memory(40000)
|
||||
class TestLongAlpaca7B(AccuracyTestHarness):
|
||||
class TestLongAlpaca7B(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "Yukang/LongAlpaca-7B"
|
||||
MODEL_PATH = f"{llm_models_root()}/LongAlpaca-7B"
|
||||
EXAMPLE_FOLDER = "llama"
|
||||
@ -267,7 +269,7 @@ class TestLongAlpaca7B(AccuracyTestHarness):
|
||||
})
|
||||
|
||||
|
||||
class TestMamba130M(AccuracyTestHarness):
|
||||
class TestMamba130M(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "state-spaces/mamba-130m-hf"
|
||||
MODEL_PATH = f"{llm_models_root()}/mamba/mamba-130m-hf"
|
||||
EXAMPLE_FOLDER = "mamba"
|
||||
@ -276,7 +278,7 @@ class TestMamba130M(AccuracyTestHarness):
|
||||
self.run(dtype='auto')
|
||||
|
||||
|
||||
class TestVicuna7B(AccuracyTestHarness):
|
||||
class TestVicuna7B(CliFlowAccuracyTestHarness):
|
||||
MODEL_NAME = "lmsys/vicuna-7b-v1.3"
|
||||
MODEL_PATH = f"{llm_models_root()}/vicuna-7b-v1.3"
|
||||
EXAMPLE_FOLDER = "llama"
|
||||
@ -288,7 +290,7 @@ class TestVicuna7B(AccuracyTestHarness):
|
||||
def test_lookahead(self, mocker):
|
||||
mocker.patch.object(CnnDailymail, "MAX_BATCH_SIZE", 8)
|
||||
|
||||
self.run(spec_dec_algo="lookahead",
|
||||
self.run(spec_dec_algo=LookaheadDecodingConfig.decoding_type,
|
||||
extra_build_args=[
|
||||
"--max_draft_len=83",
|
||||
"--speculative_decoding_mode=lookahead_decoding"
|
||||
@ -308,7 +310,7 @@ class TestVicuna7B(AccuracyTestHarness):
|
||||
extra_summarize_args.append("--cuda_graph_mode")
|
||||
|
||||
self.run(dtype="float16",
|
||||
spec_dec_algo="medusa",
|
||||
spec_dec_algo=MedusaDecodingConfig.decoding_type,
|
||||
extra_convert_args=[
|
||||
f"--medusa_model_dir={self.MEDUSA_MODEL_PATH}",
|
||||
"--num_medusa_heads=4"
|
||||
@ -339,7 +341,7 @@ class TestVicuna7B(AccuracyTestHarness):
|
||||
extra_summarize_args.extend(
|
||||
["--eagle_posterior_threshold=0.09", "--temperature=0.7"])
|
||||
|
||||
self.run(spec_dec_algo="eagle",
|
||||
self.run(spec_dec_algo=EagleDecodingConfig.decoding_type,
|
||||
extra_convert_args=[
|
||||
f"--eagle_model_dir={self.EAGLE_MODEL_PATH}",
|
||||
"--max_draft_len=63", "--num_eagle_layers=4",
|
||||
@ -351,7 +353,7 @@ class TestVicuna7B(AccuracyTestHarness):
|
||||
extra_summarize_args=extra_summarize_args)
|
||||
|
||||
|
||||
class TestLlama7B(AccuracyTestHarness):
class TestLlama7B(CliFlowAccuracyTestHarness):
MODEL_NAME = "llama-7b-hf"
MODEL_PATH = f"{llm_models_root()}/llama-models/llama-7b-hf"
EXAMPLE_FOLDER = "llama"
@ -382,7 +384,7 @@ class TestLlama7B(AccuracyTestHarness):
self.run(extra_build_args=["--fast_build"])


class TestLlama2_7B(AccuracyTestHarness):
class TestLlama2_7B(CliFlowAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-2-7b-hf"
MODEL_PATH = f"{llm_models_root()}/llama-models-v2/llama-v2-7b-hf"
EXAMPLE_FOLDER = "llama"
@ -396,7 +398,7 @@ class TestLlama2_7B(AccuracyTestHarness):
@skip_pre_ada
def test_fp8(self):
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
Mmlu(self.MODEL_NAME)],
MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.FP8,
kv_cache_quant_algo=QuantAlgo.FP8)

@ -445,14 +447,14 @@ class TestLlama2_7B(AccuracyTestHarness):
self.run(quant_algo=QuantAlgo.W4A16_AWQ, tp_size=2)

@pytest.mark.skip_less_device(2)
def test_int4_awq_pre_quantized_tp2(self, mocker):
def test_int4_awq_prequantized_tp2(self, mocker):
mocker.patch.object(
self.__class__, "MODEL_PATH",
f"{llm_models_root()}/llama-models-v2/Llama-2-7B-AWQ")
self.run(quant_algo=QuantAlgo.W4A16_AWQ, tp_size=2)

@pytest.mark.skip_less_device(2)
def test_int4_gptq_pre_quantized_tp2(self, mocker):
def test_int4_gptq_prequantized_tp2(self, mocker):
mocker.patch.object(
self.__class__, "MODEL_PATH",
f"{llm_models_root()}/llama-models-v2/Llama-2-7B-GPTQ")
@ -462,7 +464,7 @@ class TestLlama2_7B(AccuracyTestHarness):
self.run(extra_build_args=["--weight_sparsity"])


class TestTinyLlama1_1BChat(AccuracyTestHarness):
class TestTinyLlama1_1BChat(CliFlowAccuracyTestHarness):
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MODEL_PATH = f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
EXAMPLE_FOLDER = "llama"
@ -498,7 +500,7 @@ class TestTinyLlama1_1BChat(AccuracyTestHarness):
self.run(extra_acc_spec="pp_size=4", pp_size=4)


class TestLlama3_8BInstruct(AccuracyTestHarness):
class TestLlama3_8BInstruct(CliFlowAccuracyTestHarness):
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama-models-v3/llama-v3-8b-instruct-hf"
EXAMPLE_FOLDER = "llama"
@ -519,7 +521,7 @@ class TestLlama3_8BInstruct(AccuracyTestHarness):

@skip_pre_blackwell
def test_nvfp4(self):
self.run(tasks=[Mmlu(self.MODEL_NAME)],
self.run(tasks=[MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.NVFP4,
kv_cache_quant_algo=QuantAlgo.FP8,
extra_build_args=["--gemm_plugin=disable"])
@ -540,13 +542,13 @@ class TestLlama3_8BInstruct(AccuracyTestHarness):
])
if norm_quant_fusion:
extra_build_args.append("--norm_quant_fusion=enable")
self.run(tasks=[Mmlu(self.MODEL_NAME)],
self.run(tasks=[MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.NVFP4,
kv_cache_quant_algo=QuantAlgo.FP8,
extra_build_args=extra_build_args)


class TestLlama3_8BInstructGradient1048k(AccuracyTestHarness):
class TestLlama3_8BInstructGradient1048k(CliFlowAccuracyTestHarness):
MODEL_NAME = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
MODEL_PATH = f"{llm_models_root()}/llama-models-v3/Llama-3-8B-Instruct-Gradient-1048k"
EXAMPLE_FOLDER = "llama"
@ -561,7 +563,7 @@ class TestLlama3_8BInstructGradient1048k(AccuracyTestHarness):
extra_build_args=["--gather_context_logits"])


class TestLlama3_1_8B(AccuracyTestHarness):
class TestLlama3_1_8B(CliFlowAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"
EXAMPLE_FOLDER = "llama"
@ -579,7 +581,7 @@ class TestLlama3_1_8B(AccuracyTestHarness):
@skip_pre_ada
def test_fp8_rowwise(self):
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
Mmlu(self.MODEL_NAME)],
MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN)

@skip_pre_ada
@ -598,7 +600,7 @@ class TestLlama3_1_8B(AccuracyTestHarness):
extra_build_args = ["--gemm_allreduce_plugin=bfloat16"]
self.run(
tasks=[PassKeyRetrieval64k(self.MODEL_NAME),
Mmlu(self.MODEL_NAME)],
MMLU(self.MODEL_NAME)],
tp_size=4,
extra_build_args=extra_build_args)

@ -613,7 +615,7 @@ class TestLlama3_1_8B(AccuracyTestHarness):
extra_build_args = ["--gemm_allreduce_plugin=bfloat16"]
self.run(
tasks=[PassKeyRetrieval64k(self.MODEL_NAME),
Mmlu(self.MODEL_NAME)],
MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN,
tp_size=4,
extra_build_args=extra_build_args)
@ -621,7 +623,7 @@ class TestLlama3_1_8B(AccuracyTestHarness):
@skip_pre_ada
def test_autoq(self):
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
Mmlu(self.MODEL_NAME)],
MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.MIXED_PRECISION,
extra_acc_spec=
"autoq_format=int4_awq,fp8,w4a8_awq;auto_quantize_bits=5.8",
@ -632,7 +634,7 @@ class TestLlama3_1_8B(AccuracyTestHarness):
])


class TestLlama3_1_8BInstruct(AccuracyTestHarness):
class TestLlama3_1_8BInstruct(CliFlowAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
EXAMPLE_FOLDER = "llama"
@ -641,14 +643,14 @@ class TestLlama3_1_8BInstruct(AccuracyTestHarness):
self.run(dtype='auto')

@skip_pre_ada
def test_fp8_pre_quantized(self, mocker):
def test_fp8_prequantized(self, mocker):
mocker.patch.object(
self.__class__, "MODEL_PATH",
f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8")
self.run(quant_algo=QuantAlgo.FP8, kv_cache_quant_algo=QuantAlgo.FP8)

@skip_pre_ada
def test_medusa_fp8_pre_quantized(self, mocker):
def test_medusa_fp8_prequantized(self, mocker):
# nvidia/Llama-3.1-8B-Medusa-FP8
mocker.patch.object(self.__class__, "MODEL_PATH",
f"{llm_models_root()}/llama3.1-medusa-8b-hf_v0.1")
@ -659,12 +661,12 @@ class TestLlama3_1_8BInstruct(AccuracyTestHarness):
"--medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [1, 6], [0, 7, 0]]"
]
self.run(dtype="float16",
spec_dec_algo="medusa",
spec_dec_algo=MedusaDecodingConfig.decoding_type,
extra_build_args=["--speculative_decoding_mode=medusa"],
extra_summarize_args=extra_summarize_args)


class TestLlama3_2_1B(AccuracyTestHarness):
class TestLlama3_2_1B(CliFlowAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.2-1B"
MODEL_PATH = f"{llm_models_root()}/llama-3.2-models/Llama-3.2-1B"
EXAMPLE_FOLDER = "llama"
@ -768,14 +770,14 @@ class TestLlama3_2_1B(AccuracyTestHarness):
])

class TestMixtral8x7B(AccuracyTestHarness):
class TestMixtral8x7B(CliFlowAccuracyTestHarness):
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
MODEL_PATH = f"{llm_models_root()}/Mixtral-8x7B-v0.1"
EXAMPLE_FOLDER = "llama"

@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(80000)
def test_auto_dtype(self):
def test_tp2(self):
self.run(dtype='auto', tp_size=2)

@skip_pre_ada
@ -791,7 +793,7 @@ class TestMixtral8x7B(AccuracyTestHarness):
@pytest.mark.skip_less_device_memory(40000)
def test_fp8_tp2pp2(self):
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
Mmlu(self.MODEL_NAME)],
MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.FP8,
kv_cache_quant_algo=QuantAlgo.FP8,
tp_size=2,
@ -802,7 +804,7 @@ class TestMixtral8x7B(AccuracyTestHarness):
@pytest.mark.skip_less_device_memory(40000)
def test_fp8_tp2pp2_manage_weights(self):
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
Mmlu(self.MODEL_NAME)],
MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.FP8,
kv_cache_quant_algo=QuantAlgo.FP8,
tp_size=2,
@ -810,16 +812,16 @@ class TestMixtral8x7B(AccuracyTestHarness):
extra_build_args=["--fast_build"])

@skip_pre_blackwell
def test_nvfp4_pre_quantized(self, mocker):
def test_nvfp4_prequantized(self, mocker):
mocker.patch.object(
self.__class__, "MODEL_PATH",
f"{llm_models_root()}/nvfp4-quantized/Mixtral-8x7B-Instruct-v0.1")
self.run(tasks=[Mmlu(self.MODEL_NAME)],
self.run(tasks=[MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.NVFP4,
kv_cache_quant_algo=QuantAlgo.FP8)


class TestGemma2B(AccuracyTestHarness):
class TestGemma2B(CliFlowAccuracyTestHarness):
MODEL_NAME = "google/gemma-2b"
MODEL_PATH = f"{llm_models_root()}/gemma/gemma-2b"
EXAMPLE_FOLDER = "gemma"
@ -848,7 +850,7 @@ class TestGemma2B(AccuracyTestHarness):


@pytest.mark.skip_less_device_memory(40000)
class TestGemma7B(AccuracyTestHarness):
class TestGemma7B(CliFlowAccuracyTestHarness):
MODEL_NAME = "google/gemma-7b"
MODEL_PATH = f"{llm_models_root()}/gemma/gemma-7b"
EXAMPLE_FOLDER = "gemma"
@ -878,14 +880,14 @@ class TestGemma7B(AccuracyTestHarness):


@pytest.mark.skip_less_device_memory(40000)
class TestGemma2_9BIt(AccuracyTestHarness):
class TestGemma2_9BIt(CliFlowAccuracyTestHarness):
MODEL_NAME = "google/gemma-2-9b-it"
MODEL_PATH = f"{llm_models_root()}/gemma/gemma-2-9b-it"
EXAMPLE_FOLDER = "gemma"

def test_auto_dtype(self):
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
Mmlu(self.MODEL_NAME)],
MMLU(self.MODEL_NAME)],
dtype='auto',
extra_convert_args=["--ckpt-type=hf"])

@ -901,7 +903,7 @@ class TestGemma2_9BIt(AccuracyTestHarness):
extra_convert_args=["--device_map=sequential"])


class TestQwen7BChat(AccuracyTestHarness):
class TestQwen7BChat(CliFlowAccuracyTestHarness):
MODEL_NAME = "Qwen/Qwen-7B-Chat"
MODEL_PATH = f"{llm_models_root()}/Qwen-7B-Chat"
EXAMPLE_FOLDER = "qwen"
@ -912,14 +914,14 @@ class TestQwen7BChat(AccuracyTestHarness):
def test_weight_only(self):
self.run(quant_algo=QuantAlgo.W8A16)

def test_int4_gptq_pre_quantized(self, mocker):
def test_int4_gptq_prequantized(self, mocker):
mocker.patch.object(self.__class__, "MODEL_PATH",
f"{llm_models_root()}/Qwen-7B-Chat-Int4")
self.run(quant_algo=QuantAlgo.W4A16_GPTQ)


@pytest.mark.skip_less_device_memory(40000)
class TestQwen1_5MoeA2_7BChat(AccuracyTestHarness):
class TestQwen1_5MoeA2_7BChat(CliFlowAccuracyTestHarness):
MODEL_NAME = "Qwen/Qwen1.5-MoE-A2.7B-Chat"
MODEL_PATH = f"{llm_models_root()}/Qwen1.5-MoE-A2.7B-Chat"
EXAMPLE_FOLDER = "qwen"
@ -932,7 +934,7 @@ class TestQwen1_5MoeA2_7BChat(AccuracyTestHarness):
self.run(quant_algo=QuantAlgo.W8A16)


class TestQwen2_0_5BInstruct(AccuracyTestHarness):
class TestQwen2_0_5BInstruct(CliFlowAccuracyTestHarness):
MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"
MODEL_PATH = f"{llm_models_root()}/Qwen2-0.5B-Instruct"
EXAMPLE_FOLDER = "qwen"
@ -946,11 +948,11 @@ class TestQwen2_0_5BInstruct(AccuracyTestHarness):
@skip_pre_ada
def test_fp8(self):
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
Mmlu(self.MODEL_NAME)],
MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.FP8)


class TestQwen2_7BInstruct(AccuracyTestHarness):
class TestQwen2_7BInstruct(CliFlowAccuracyTestHarness):
MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
MODEL_PATH = f"{llm_models_root()}/Qwen2-7B-Instruct"
EXAMPLE_FOLDER = "qwen"
@ -961,14 +963,14 @@ class TestQwen2_7BInstruct(AccuracyTestHarness):
def test_weight_only(self):
self.run(quant_algo=QuantAlgo.W8A16)

def test_int4_awq_pre_quantized(self, mocker):
def test_int4_awq_prequantized(self, mocker):
mocker.patch.object(self.__class__, "MODEL_PATH",
f"{llm_models_root()}/Qwen2-7B-Instruct-AWQ")
self.run(quant_algo=QuantAlgo.W4A16_AWQ)


@pytest.mark.skip_less_device_memory(40000)
class TestQwen2_57B_A14B(AccuracyTestHarness):
class TestQwen2_57B_A14B(CliFlowAccuracyTestHarness):
MODEL_NAME = "Qwen/Qwen2-57B-A14B"
MODEL_PATH = f"{llm_models_root()}/Qwen2-57B-A14B"
EXAMPLE_FOLDER = "qwen"
@ -984,7 +986,7 @@ class TestQwen2_57B_A14B(AccuracyTestHarness):
self.run(tp_size=2, pp_size=2)


class TestQwen2_5_1_5BInstruct(AccuracyTestHarness):
class TestQwen2_5_1_5BInstruct(CliFlowAccuracyTestHarness):
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
MODEL_PATH = f"{llm_models_root()}/Qwen2.5-1.5B-Instruct"
EXAMPLE_FOLDER = "qwen"
@ -998,5 +1000,5 @@ class TestQwen2_5_1_5BInstruct(AccuracyTestHarness):
@skip_pre_ada
def test_fp8(self):
self.run(tasks=[CnnDailymail(self.MODEL_NAME),
Mmlu(self.MODEL_NAME)],
MMLU(self.MODEL_NAME)],
quant_algo=QuantAlgo.FP8)
50 tests/integration/defs/accuracy/test_llm_api.py Normal file
@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from tensorrt_llm.llmapi import LLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo

from ..conftest import llm_models_root, skip_pre_ada
from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness


class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"

@skip_pre_ada
def test_fp8_rowwise(self):
quant_config = QuantConfig(QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN)

with LLM(self.MODEL_PATH, quant_config=quant_config) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)


class TestMixtral8x7B(LlmapiAccuracyTestHarness):
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
MODEL_PATH = f"{llm_models_root()}/Mixtral-8x7B-v0.1"

@pytest.mark.skip_less_device(2)
def test_tp2(self):
with LLM(self.MODEL_PATH, tensor_parallel_size=2) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
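The new test_llm_api.py above establishes the LLM-API accuracy pattern: a LlmapiAccuracyTestHarness subclass declares MODEL_NAME and MODEL_PATH, and each test builds an LLM and scores it via task.evaluate. As a hedged illustration only (the class below is hypothetical and not part of this commit; the TinyLlama checkpoint path is reused from test_cli_flow.py above), adding another model under the same pattern could look like this:

# Hypothetical addition to the LLM-API accuracy suite, following the pattern
# of test_llm_api.py above. Not part of this change; shown for illustration.
from tensorrt_llm.llmapi import LLM

from ..conftest import llm_models_root
from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness


class TestTinyLlama1_1BChat(LlmapiAccuracyTestHarness):
    MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    MODEL_PATH = f"{llm_models_root()}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0"

    def test_auto_dtype(self):
        # Build with default settings, then score both accuracy tasks.
        with LLM(self.MODEL_PATH) as llm:
            task = CnnDailymail(self.MODEL_NAME)
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)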
146 tests/integration/defs/accuracy/test_llm_api_pytorch.py Normal file
@ -0,0 +1,146 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.quantization import QuantAlgo

from ..conftest import llm_models_root, skip_pre_blackwell
from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness


class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"

def test_auto_dtype(self):
with LLM(self.MODEL_PATH) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_blackwell
def test_nvfp4(self):
model_path = f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B"
with LLM(model_path) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)


class TestLlama3_3_70BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"

@pytest.mark.skip_less_device(4)
@pytest.mark.skip_device_not_contain(["H100", "B200"])
def test_fp8_tp4(self):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp8"
with LLM(model_path, tensor_parallel_size=4) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@pytest.mark.skip_device_not_contain(["B200"])
def test_nvfp4_tp4(self):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp4"
with LLM(model_path, tensor_parallel_size=4) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)


class TestMixtral8x7B(LlmapiAccuracyTestHarness):
MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"
MODEL_PATH = f"{llm_models_root()}/Mixtral-8x7B-v0.1"

@pytest.mark.skip_less_device(2)
def test_tp2(self):
with LLM(self.MODEL_PATH, tensor_parallel_size=2) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(2)
@pytest.mark.skip_device_not_contain(["H100", "B200"])
def test_fp8_tp2(self):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp8"
with LLM(model_path, tensor_parallel_size=2) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(2)
@pytest.mark.skip_device_not_contain(["B200"])
def test_nvfp4_tp2(self):
model_path = f"{llm_models_root()}/modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp4"
with LLM(model_path, tensor_parallel_size=2) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)


class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
MODEL_NAME = "deepseek-ai/DeepSeek-V3-Lite"
MODEL_PATH = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16"

@pytest.mark.skip_less_device_memory(80000)
def test_auto_dtype(self):
# https://nvbugs/5141289: OOM on H100 with default free_gpu_memory_fraction=0.9
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config) as llm:
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_device_not_contain(["H100"])
def test_fp8_block_scales(self):
model_path = f"{llm_models_root()}/DeepSeek-V3-Lite/fp8"
# https://nvbugs/5141289: OOM on H100 with default free_gpu_memory_fraction=0.9
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
with LLM(model_path, kv_cache_config=kv_cache_config) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_blackwell
def test_nvfp4(self):
model_path = f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only"
with LLM(model_path) as llm:
assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
task = CnnDailymail(self.MODEL_NAME)
task.evaluate(llm)
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
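test_llm_api_pytorch.py mirrors the same harness but drives the PyTorch backend (LLM imported from tensorrt_llm._torch), checks the quantization algorithm inferred from prequantized Hugging Face checkpoints, and caps the KV-cache memory fraction where the default would OOM. A hedged, hypothetical sketch of covering one more prequantized checkpoint in this style (the class name and marker choice are illustrative, not part of this change; the FP8 checkpoint path is reused from test_cli_flow.py above):

# Hypothetical example, not part of this commit: a PyTorch-backend accuracy test
# for another FP8-prequantized checkpoint, mirroring the structure above.
import pytest

from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.quantization import QuantAlgo

from ..conftest import llm_models_root
from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness


class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
    MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
    # Illustrative path: an FP8-prequantized checkpoint referenced elsewhere in this change.
    MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"

    @pytest.mark.skip_device_not_contain(["H100", "B200"])
    def test_fp8_prequantized(self):
        # Cap the KV-cache fraction, as the DeepSeek-V3-Lite tests above do, to avoid OOM.
        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
        with LLM(self.MODEL_PATH, kv_cache_config=kv_cache_config) as llm:
            # The quant algo should be picked up from the checkpoint's HF quant config.
            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
            task = CnnDailymail(self.MODEL_NAME)
            task.evaluate(llm)
            task = MMLU(self.MODEL_NAME)
            task.evaluate(llm)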
@ -1,157 +0,0 @@
import os

import pytest
from defs.common import venv_check_call
from defs.conftest import llm_models_root, skip_pre_blackwell


@pytest.mark.parametrize("enable_fp4",
[pytest.param(True, marks=skip_pre_blackwell), False],
ids=["enable_fp4", "disable_fp4"])
@pytest.mark.parametrize("model_name", ["llama-3.1-8b"])
def test_llm_llama_1gpu(
mmlu_dataset_root,
enable_fp4,
llama_example_root,
model_name,
llm_venv,
):
models_root = llm_models_root()
if enable_fp4:
model_dir = os.path.join(models_root, "nvfp4-quantized",
"Meta-Llama-3.1-8B")
else:
model_dir = os.path.join(models_root, "llama-3.1-model",
"Meta-Llama-3.1-8B")

print("Run MMLU test")
accuracy_map = {
'llama-3.1-8b': 61,
}
acc_thres = accuracy_map[model_name]
mmlu_cmd = [
f"{llama_example_root}/../mmlu_llmapi.py",
f"--data_dir={mmlu_dataset_root}",
f"--hf_model_dir={model_dir}",
"--backend=pytorch",
"--check_accuracy",
"--enable_chunked_prefill",
f"--accuracy_threshold={acc_thres}",
]

venv_check_call(llm_venv, mmlu_cmd)


@pytest.mark.parametrize("enable_fp4",
[pytest.param(True, marks=skip_pre_blackwell), False],
ids=["enable_fp4", "disable_fp4"])
@pytest.mark.parametrize("enable_fp8", [
pytest.param(True, marks=pytest.mark.skip_device_not_contain(["H100"])),
False
],
ids=["enable_fp8", "disable_fp8"])
@pytest.mark.parametrize("model_name", ["deepseek-v3-lite"])
def test_llm_deepseek_1gpu(
mmlu_dataset_root,
enable_fp4,
enable_fp8,
llama_example_root,
model_name,
llm_venv,
):
models_root = llm_models_root()
if enable_fp4:
model_dir = os.path.join(models_root, "DeepSeek-V3-Lite",
"nvfp4_moe_only")
elif enable_fp8:
model_dir = os.path.join(models_root, "DeepSeek-V3-Lite", "fp8")
else:
model_dir = os.path.join(models_root, "DeepSeek-V3-Lite", "bf16")

print("Run MMLU test")
accuracy_map = {
'deepseek-v3-lite': 68,
}
acc_thres = accuracy_map[model_name]
mmlu_cmd = [
f"{llama_example_root}/../mmlu_llmapi.py",
f"--data_dir={mmlu_dataset_root}",
f"--hf_model_dir={model_dir}",
"--backend=pytorch",
"--check_accuracy",
"--enable_overlap_scheduler",
"--kv_cache_free_gpu_memory_fraction=0.8",
f"--accuracy_threshold={acc_thres}",
]

venv_check_call(llm_venv, mmlu_cmd)


@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("model_name,model_path", [
pytest.param('Llama-3.3-70B-Instruct-fp8',
'modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp8',
marks=pytest.mark.skip_device_not_contain(["B200", "H100"])),
pytest.param('Llama-3.3-70B-Instruct-fp4',
'modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp4',
marks=pytest.mark.skip_device_not_contain(["B200"])),
])
def test_mmlu_llmapi_4gpus(llm_venv, llama_example_root, mmlu_dataset_root,
model_name, model_path):
models_root = llm_models_root()
model_dir = os.path.join(models_root, model_path)

print(f"Run MMLU test on {model_name}.")
accuracy_map = {
'Llama-3.3-70B-Instruct-fp8': 80.4,
'Llama-3.3-70B-Instruct-fp4': 78.5,
}
acc_thres = accuracy_map[model_name]
mmlu_cmd = [
f"{llama_example_root}/../mmlu_llmapi.py",
f"--data_dir={mmlu_dataset_root}",
f"--hf_model_dir={model_dir}",
"--backend=pytorch",
"--check_accuracy",
"--enable_chunked_prefill",
f"--accuracy_threshold={acc_thres}",
f"--tp_size=4",
]

venv_check_call(llm_venv, mmlu_cmd)


@pytest.mark.skip_less_device_memory(80000)
@pytest.mark.skip_less_device(2)
@pytest.mark.parametrize("model_name,model_path", [
pytest.param('Mixtral-8x7B-Instruct-v0.1-fp8',
'modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp8',
marks=pytest.mark.skip_device_not_contain(["B200", "H100"])),
pytest.param('Mixtral-8x7B-Instruct-v0.1-fp4',
'modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp4',
marks=pytest.mark.skip_device_not_contain(["B200"])),
])
def test_mmlu_llmapi_2gpus(llm_venv, llama_example_root, mmlu_dataset_root,
model_name, model_path):
models_root = llm_models_root()
model_dir = os.path.join(models_root, model_path)

print(f"Run MMLU test on {model_name}.")
accuracy_map = {
'Mixtral-8x7B-Instruct-v0.1-fp8': 67.9,
'Mixtral-8x7B-Instruct-v0.1-fp4': 66.9,
}
acc_thres = accuracy_map[model_name]
mmlu_cmd = [
f"{llama_example_root}/../mmlu_llmapi.py",
f"--data_dir={mmlu_dataset_root}",
f"--hf_model_dir={model_dir}",
"--backend=pytorch",
"--check_accuracy",
"--enable_chunked_prefill",
f"--accuracy_threshold={acc_thres}",
f"--tp_size=2",
]

venv_check_call(llm_venv, mmlu_cmd)
@ -233,10 +233,6 @@ examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3-mini-128k-instruct-fp
examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3-small-128k-instruct-fp8-bfloat16]
examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-mini-instruct-fp8-float16]
examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-MoE-instruct-fp8-bfloat16]
examples/test_pytorch.py::test_llm_llama_1gpu[llama-3.1-8b-enable_fp4]
examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-enable_fp4]
examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-disable_fp4]
examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-enable_fp8-disable_fp4]
examples/test_qwen.py::test_llm_qwen1_5_7b_single_gpu_lora[qwen1.5_7b_chat-Qwen1.5-7B-Chat-750Mb-lora]
examples/test_qwen.py::test_llm_qwen1_5_moe_plugin_single_gpu_lora[qwen1.5_moe_a2.7b_chat-Upcycled-Qwen1.5-MoE2.7B-LoRA]
examples/test_qwen.py::test_llm_qwen1_5_moe_single_gpu_lora[qwen1.5_moe_a2.7b_chat-Upcycled-Qwen1.5-MoE2.7B-LoRA]
@ -302,123 +298,131 @@ examples/test_deepseek.py::test_deepseek_gpqa_llmapi[enable_overlap_scheduler-en
examples/test_deepseek.py::test_deepseek_gpqa_llmapi[enable_overlap_scheduler-enable_cuda_graph-disable_dp-nextn2-ep8-pp1-tp8-fp4-deepseek_r1]

# Accuracy test list
accuracy/test_accuracy.py::TestGpt2::test_auto_dtype
accuracy/test_accuracy.py::TestGpt2::test_gemm_plugin
accuracy/test_accuracy.py::TestGpt2::test_attention_ootb
accuracy/test_accuracy.py::TestGpt2::test_context_fmha_disabled
accuracy/test_accuracy.py::TestGpt2::test_context_fmha_fp32_acc
accuracy/test_accuracy.py::TestGpt2::test_weight_only[int8]
accuracy/test_accuracy.py::TestGpt2::test_weight_only[int4]
accuracy/test_accuracy.py::TestGpt2::test_int8_kv_cache
accuracy/test_accuracy.py::TestGpt2::test_smooth_quant[]
accuracy/test_accuracy.py::TestGpt2::test_smooth_quant[per_token-per_channel]
accuracy/test_accuracy.py::TestGpt2::test_beam_search
accuracy/test_accuracy.py::TestGpt2::test_beam_search_large
accuracy/test_accuracy.py::TestGpt2::test_weight_streaming_ootb
accuracy/test_accuracy.py::TestGpt2::test_weight_streaming_plugin
accuracy/test_accuracy.py::TestGpt2::test_cuda_graph
accuracy/test_accuracy.py::TestGpt2Medium::test_auto_dtype
accuracy/test_accuracy.py::TestGpt2Medium::test_fp8
accuracy/test_accuracy.py::TestGpt2Medium::test_fp8_lm_head
accuracy/test_accuracy.py::TestSantacoder::test_auto_dtype
accuracy/test_accuracy.py::TestStarcoder2_3B::test_auto_dtype
accuracy/test_accuracy.py::TestStarcoder2_15B::test_smooth_quant_ootb
accuracy/test_accuracy.py::TestGptNext::test_auto_dtype
accuracy/test_accuracy.py::TestMinitron4BBase::test_auto_dtype
accuracy/test_accuracy.py::TestMinitron4BBase::test_fp8
accuracy/test_accuracy.py::TestNemotronMini4BInstruct::test_fp8_pre_quantized
accuracy/test_accuracy.py::TestPhi2::test_auto_dtype
accuracy/test_accuracy.py::TestPhi2::test_tp2
accuracy/test_accuracy.py::TestPhi3Mini4kInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestPhi3Mini128kInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestPhi3Small8kInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestPhi3Small128kInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestPhi3_5MiniInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestLongAlpaca7B::test_auto_dtype
accuracy/test_accuracy.py::TestLongAlpaca7B::test_multiblock_aggressive
accuracy/test_accuracy.py::TestMamba130M::test_auto_dtype
accuracy/test_accuracy.py::TestVicuna7B::test_lookahead
accuracy/test_accuracy.py::TestVicuna7B::test_medusa[]
accuracy/test_accuracy.py::TestVicuna7B::test_medusa[cuda_graph]
accuracy/test_accuracy.py::TestVicuna7B::test_eagle[]
accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph]
accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph-chunked_context]
accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph-typical_acceptance]
accuracy/test_accuracy.py::TestLlama7B::test_auto_dtype
accuracy/test_accuracy.py::TestLlama7B::test_beam_search
accuracy/test_accuracy.py::TestLlama7B::test_int4_gptq
accuracy/test_accuracy.py::TestLlama7B::test_streamingllm
accuracy/test_accuracy.py::TestLlama2_7B::test_auto_dtype
accuracy/test_accuracy.py::TestLlama2_7B::test_smooth_quant
accuracy/test_accuracy.py::TestLlama2_7B::test_fp8
accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_2gpus[tp2]
accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_2gpus[pp2]
accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_2gpus[cp2]
accuracy/test_accuracy.py::TestLlama2_7B::test_tp2cp2
accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_gemm_plugin
accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_gemm_swiglu_plugin
accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_low_latency_gemm_plugin
accuracy/test_accuracy.py::TestLlama2_7B::test_smooth_quant_ootb_tp2
accuracy/test_accuracy.py::TestLlama2_7B::test_int4_awq_tp2
accuracy/test_accuracy.py::TestLlama2_7B::test_int4_awq_pre_quantized_tp2
accuracy/test_accuracy.py::TestLlama2_7B::test_int4_gptq_pre_quantized_tp2
accuracy/test_accuracy.py::TestLlama2_7B::test_weight_sparsity
accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_float32
accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_weight_only[int8]
accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_weight_only[int4]
accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_weight_only_int8_kv_cache[int8]
accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_pp4
accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_fp8
accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4
accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-disable_fused_quant]
accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-enable_fused_quant]
accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-disable_fused_quant]
accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-enable_fused_quant]
accuracy/test_accuracy.py::TestLlama3_8BInstructGradient1048k::test_long_context
accuracy/test_accuracy.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl
accuracy/test_accuracy.py::TestLlama3_1_8B::test_auto_dtype
accuracy/test_accuracy.py::TestLlama3_1_8B::test_fp8
accuracy/test_accuracy.py::TestLlama3_1_8B::test_fp8_rowwise
accuracy/test_accuracy.py::TestLlama3_1_8B::test_tp4[disable_gemm_allreduce_plugin]
accuracy/test_accuracy.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin]
accuracy/test_accuracy.py::TestLlama3_1_8B::test_fp8_rowwise_tp4[disable_gemm_allreduce_plugin]
accuracy/test_accuracy.py::TestLlama3_1_8B::test_fp8_rowwise_tp4[enable_gemm_allreduce_plugin]
accuracy/test_accuracy.py::TestLlama3_1_8B::test_autoq
accuracy/test_accuracy.py::TestLlama3_1_8BInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestLlama3_1_8BInstruct::test_fp8_pre_quantized
accuracy/test_accuracy.py::TestLlama3_1_8BInstruct::test_medusa_fp8_pre_quantized
accuracy/test_accuracy.py::TestLlama3_2_1B::test_auto_dtype
accuracy/test_accuracy.py::TestLlama3_2_1B::test_smooth_quant
accuracy/test_accuracy.py::TestLlama3_2_1B::test_smooth_quant_ootb
accuracy/test_accuracy.py::TestLlama3_2_1B::test_int4_awq
accuracy/test_accuracy.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache
accuracy/test_accuracy.py::TestLlama3_2_1B::test_int4_awq_manage_weights
accuracy/test_accuracy.py::TestLlama3_2_1B::test_fp8
accuracy/test_accuracy.py::TestLlama3_2_1B::test_fp8_pp2
accuracy/test_accuracy.py::TestLlama3_2_1B::test_fp8_rowwise
accuracy/test_accuracy.py::TestLlama3_2_1B::test_weight_streaming[1.0]
accuracy/test_accuracy.py::TestLlama3_2_1B::test_cyclic_kv_cache
accuracy/test_accuracy.py::TestLlama3_2_1B::test_cyclic_kv_cache_beam_search
accuracy/test_accuracy.py::TestMixtral8x7B::test_fp8_tp2pp2
accuracy/test_accuracy.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights
accuracy/test_accuracy.py::TestMixtral8x7B::test_nvfp4_pre_quantized
accuracy/test_accuracy.py::TestGemma2_9BIt::test_auto_dtype
accuracy/test_accuracy.py::TestGemma2_9BIt::test_weight_only[int8]
accuracy/test_accuracy.py::TestGemma2_9BIt::test_weight_only[int4]
accuracy/test_accuracy.py::TestQwen1_5MoeA2_7BChat::test_auto_dtype
accuracy/test_accuracy.py::TestQwen1_5MoeA2_7BChat::test_weight_only
accuracy/test_accuracy.py::TestQwen2_0_5BInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestQwen2_0_5BInstruct::test_weight_only
accuracy/test_accuracy.py::TestQwen2_0_5BInstruct::test_fp8
accuracy/test_accuracy.py::TestQwen2_7BInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestQwen2_7BInstruct::test_weight_only
accuracy/test_accuracy.py::TestQwen2_7BInstruct::test_int4_awq_pre_quantized
accuracy/test_accuracy.py::TestQwen2_57B_A14B::test_tp4
accuracy/test_accuracy.py::TestQwen2_57B_A14B::test_tp2pp2
accuracy/test_accuracy.py::TestQwen2_5_1_5BInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestQwen2_5_1_5BInstruct::test_weight_only
accuracy/test_accuracy.py::TestQwen2_5_1_5BInstruct::test_fp8
accuracy/test_cli_flow.py::TestGpt2::test_auto_dtype
accuracy/test_cli_flow.py::TestGpt2::test_gemm_plugin
accuracy/test_cli_flow.py::TestGpt2::test_attention_ootb
accuracy/test_cli_flow.py::TestGpt2::test_context_fmha_disabled
accuracy/test_cli_flow.py::TestGpt2::test_context_fmha_fp32_acc
accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int8]
accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int4]
accuracy/test_cli_flow.py::TestGpt2::test_int8_kv_cache
accuracy/test_cli_flow.py::TestGpt2::test_smooth_quant[]
accuracy/test_cli_flow.py::TestGpt2::test_smooth_quant[per_token-per_channel]
accuracy/test_cli_flow.py::TestGpt2::test_beam_search
accuracy/test_cli_flow.py::TestGpt2::test_beam_search_large
accuracy/test_cli_flow.py::TestGpt2::test_weight_streaming_ootb
accuracy/test_cli_flow.py::TestGpt2::test_weight_streaming_plugin
accuracy/test_cli_flow.py::TestGpt2::test_cuda_graph
accuracy/test_cli_flow.py::TestGpt2Medium::test_auto_dtype
accuracy/test_cli_flow.py::TestGpt2Medium::test_fp8
accuracy/test_cli_flow.py::TestGpt2Medium::test_fp8_lm_head
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype
accuracy/test_cli_flow.py::TestStarcoder2_3B::test_auto_dtype
accuracy/test_cli_flow.py::TestStarcoder2_15B::test_smooth_quant_ootb
accuracy/test_cli_flow.py::TestGptNext::test_auto_dtype
accuracy/test_cli_flow.py::TestMinitron4BBase::test_auto_dtype
accuracy/test_cli_flow.py::TestMinitron4BBase::test_fp8
accuracy/test_cli_flow.py::TestNemotronMini4BInstruct::test_fp8_prequantized
accuracy/test_cli_flow.py::TestPhi2::test_auto_dtype
accuracy/test_cli_flow.py::TestPhi2::test_tp2
accuracy/test_cli_flow.py::TestPhi3Mini4kInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestPhi3Mini128kInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestPhi3Small8kInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestPhi3Small128kInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestPhi3_5MiniInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype
accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive
accuracy/test_cli_flow.py::TestMamba130M::test_auto_dtype
accuracy/test_cli_flow.py::TestVicuna7B::test_lookahead
accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[]
accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[cuda_graph]
accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[]
accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph]
accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph-chunked_context]
accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph-typical_acceptance]
accuracy/test_cli_flow.py::TestLlama7B::test_auto_dtype
accuracy/test_cli_flow.py::TestLlama7B::test_beam_search
accuracy/test_cli_flow.py::TestLlama7B::test_int4_gptq
accuracy/test_cli_flow.py::TestLlama7B::test_streamingllm
accuracy/test_cli_flow.py::TestLlama2_7B::test_auto_dtype
accuracy/test_cli_flow.py::TestLlama2_7B::test_smooth_quant
accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8
accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_2gpus[tp2]
accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_2gpus[pp2]
accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_2gpus[cp2]
accuracy/test_cli_flow.py::TestLlama2_7B::test_tp2cp2
accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_gemm_plugin
accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_gemm_swiglu_plugin
accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_low_latency_gemm_plugin
accuracy/test_cli_flow.py::TestLlama2_7B::test_smooth_quant_ootb_tp2
accuracy/test_cli_flow.py::TestLlama2_7B::test_int4_awq_tp2
accuracy/test_cli_flow.py::TestLlama2_7B::test_int4_awq_prequantized_tp2
accuracy/test_cli_flow.py::TestLlama2_7B::test_int4_gptq_prequantized_tp2
accuracy/test_cli_flow.py::TestLlama2_7B::test_weight_sparsity
accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_float32
accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only[int8]
accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only[int4]
accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only_int8_kv_cache[int8]
accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_pp4
accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_fp8
accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4
accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-disable_fused_quant]
accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-enable_fused_quant]
accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-disable_fused_quant]
accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-enable_fused_quant]
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context
accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_auto_dtype
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_fp8
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[disable_gemm_allreduce_plugin]
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin]
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_fp8_rowwise_tp4[disable_gemm_allreduce_plugin]
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_fp8_rowwise_tp4[enable_gemm_allreduce_plugin]
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_autoq
accuracy/test_cli_flow.py::TestLlama3_1_8BInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestLlama3_1_8BInstruct::test_fp8_prequantized
accuracy/test_cli_flow.py::TestLlama3_1_8BInstruct::test_medusa_fp8_prequantized
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_auto_dtype
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_smooth_quant
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_smooth_quant_ootb
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_int4_awq
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_int4_awq_manage_weights
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_fp8
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_fp8_pp2
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_fp8_rowwise
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_weight_streaming[1.0]
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache_beam_search
accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2
accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights
accuracy/test_cli_flow.py::TestMixtral8x7B::test_nvfp4_prequantized
accuracy/test_cli_flow.py::TestGemma2_9BIt::test_auto_dtype
accuracy/test_cli_flow.py::TestGemma2_9BIt::test_weight_only[int8]
accuracy/test_cli_flow.py::TestGemma2_9BIt::test_weight_only[int4]
accuracy/test_cli_flow.py::TestQwen1_5MoeA2_7BChat::test_auto_dtype
accuracy/test_cli_flow.py::TestQwen1_5MoeA2_7BChat::test_weight_only
accuracy/test_cli_flow.py::TestQwen2_0_5BInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestQwen2_0_5BInstruct::test_weight_only
accuracy/test_cli_flow.py::TestQwen2_0_5BInstruct::test_fp8
accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_weight_only
accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_int4_awq_prequantized
accuracy/test_cli_flow.py::TestQwen2_57B_A14B::test_tp4
accuracy/test_cli_flow.py::TestQwen2_57B_A14B::test_tp2pp2
accuracy/test_cli_flow.py::TestQwen2_5_1_5BInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestQwen2_5_1_5BInstruct::test_weight_only
accuracy/test_cli_flow.py::TestQwen2_5_1_5BInstruct::test_fp8
accuracy/test_llm_api.py::TestLlama3_1_8B::test_fp8_rowwise
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_fp8_tp2
accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_nvfp4_tp2
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4

test_e2e.py::test_benchmark_sanity[bert_base] # 127.18s
test_e2e.py::test_benchmark_sanity[gpt_350m] # 64.06s
@ -465,10 +469,6 @@ test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Me
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
test_e2e.py::test_trtllm_bench_mgmn
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
examples/test_pytorch.py::test_mmlu_llmapi_4gpus[Llama-3.3-70B-Instruct-fp8-modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp8]
examples/test_pytorch.py::test_mmlu_llmapi_4gpus[Llama-3.3-70B-Instruct-fp4-modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp4]
examples/test_pytorch.py::test_mmlu_llmapi_2gpus[Mixtral-8x7B-Instruct-v0.1-fp8-modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp8]
examples/test_pytorch.py::test_mmlu_llmapi_2gpus[Mixtral-8x7B-Instruct-v0.1-fp4-modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp4]
examples/test_medusa.py::test_codellama_medusa_1gpu[CodeLlama-7b-Instruct]
examples/test_medusa.py::test_llama_medusa_1gpu[llama-v2-7b-hf]
examples/test_medusa.py::test_llama_medusa_1gpu[llama-3.2-1b]

@ -103,29 +103,36 @@ test_e2e.py::test_openai_multi_chat_example
test_e2e.py::test_openai_consistent_chat

# Accuracy test list
accuracy/test_accuracy.py::TestStarcoder2_3B::test_auto_dtype
accuracy/test_accuracy.py::TestMinitron4BBase::test_auto_dtype
accuracy/test_accuracy.py::TestPhi3Mini4kInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestPhi3Mini128kInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestPhi3Small8kInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestPhi3Small128kInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestPhi3_5MiniInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_fp8
accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-disable_fused_quant]
accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-enable_fused_quant]
accuracy/test_accuracy.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin]
accuracy/test_accuracy.py::TestLlama3_1_8B::test_fp8_rowwise_tp4[disable_gemm_allreduce_plugin]
accuracy/test_accuracy.py::TestLlama3_1_8B::test_autoq
accuracy/test_accuracy.py::TestLlama3_1_8BInstruct::test_medusa_fp8_pre_quantized
accuracy/test_accuracy.py::TestLlama3_2_1B::test_auto_dtype
accuracy/test_accuracy.py::TestMixtral8x7B::test_fp8_tp2pp2
accuracy/test_accuracy.py::TestGemma2_9BIt::test_fp8
accuracy/test_accuracy.py::TestQwen2_0_5BInstruct::test_auto_dtype
accuracy/test_accuracy.py::TestQwen2_0_5BInstruct::test_fp8
accuracy/test_accuracy.py::TestQwen2_7BInstruct::test_weight_only
accuracy/test_accuracy.py::TestQwen2_7BInstruct::test_int4_awq_pre_quantized
accuracy/test_accuracy.py::TestQwen2_5_1_5BInstruct::test_weight_only
accuracy/test_accuracy.py::TestQwen2_5_1_5BInstruct::test_fp8
accuracy/test_cli_flow.py::TestStarcoder2_3B::test_auto_dtype
accuracy/test_cli_flow.py::TestMinitron4BBase::test_auto_dtype
accuracy/test_cli_flow.py::TestPhi3Mini4kInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestPhi3Mini128kInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestPhi3Small8kInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestPhi3Small128kInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestPhi3_5MiniInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_fp8
accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-disable_fused_quant]
accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-enable_fused_quant]
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin]
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_fp8_rowwise_tp4[disable_gemm_allreduce_plugin]
accuracy/test_cli_flow.py::TestLlama3_1_8B::test_autoq
accuracy/test_cli_flow.py::TestLlama3_1_8BInstruct::test_medusa_fp8_prequantized
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_auto_dtype
accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2
accuracy/test_cli_flow.py::TestGemma2_9BIt::test_fp8
accuracy/test_cli_flow.py::TestQwen2_0_5BInstruct::test_auto_dtype
accuracy/test_cli_flow.py::TestQwen2_0_5BInstruct::test_fp8
accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_weight_only
accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_int4_awq_prequantized
accuracy/test_cli_flow.py::TestQwen2_5_1_5BInstruct::test_weight_only
accuracy/test_cli_flow.py::TestQwen2_5_1_5BInstruct::test_fp8
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_fp8_tp2
accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_nvfp4_tp2
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4

# Pivot to Pytorch test cases.
test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
@ -138,13 +145,6 @@ test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B]
test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False]
examples/test_pytorch.py::test_mmlu_llmapi_4gpus[Llama-3.3-70B-Instruct-fp8-modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp8]
examples/test_pytorch.py::test_mmlu_llmapi_4gpus[Llama-3.3-70B-Instruct-fp4-modelopt-hf-model-hub/Llama-3.3-70B-Instruct-fp4]
examples/test_pytorch.py::test_mmlu_llmapi_2gpus[Mixtral-8x7B-Instruct-v0.1-fp8-modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp8]
examples/test_pytorch.py::test_mmlu_llmapi_2gpus[Mixtral-8x7B-Instruct-v0.1-fp4-modelopt-hf-model-hub/Mixtral-8x7B-Instruct-v0.1-fp4]
examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-enable_fp4]
examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-disable_fp4]
examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-enable_fp8-disable_fp4]

# PyTorch flow disaggregated tests
disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]

@ -45,7 +45,7 @@ l0_a10:
|
||||
- test_e2e.py::test_trtllm_bench_latency_sanity[FP16-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
- test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency]
- unittest/trt/quantization # 18 mins
- accuracy/test_accuracy.py::TestLlama7B::test_streamingllm # 2 mins
- accuracy/test_cli_flow.py::TestLlama7B::test_streamingllm # 2 mins
- unittest/trt/functional # 37 mins
- test_cache.py::test_cache_sanity # 1 sec
- test_e2e.py::test_llmapi_quickstart_atexit
@ -55,22 +55,22 @@ l0_a10:
- unittest/llmapi/test_llm_perf_evaluator.py
- unittest/llmapi/test_build_cache.py
- unittest/llmapi/test_llm_utils.py
- accuracy/test_accuracy.py::TestGpt2::test_auto_dtype # 1 min
- accuracy/test_accuracy.py::TestGpt2::test_beam_search # 1 min
- accuracy/test_accuracy.py::TestGpt2::test_beam_search_large # 6 mins
- accuracy/test_accuracy.py::TestGpt2::test_weight_streaming_ootb # 1.5 mins
- accuracy/test_accuracy.py::TestGpt2::test_cuda_graph # 1 min
- accuracy/test_accuracy.py::TestSantacoder::test_auto_dtype # 1.5 mins
- accuracy/test_accuracy.py::TestMamba130M::test_auto_dtype # 1 min
- accuracy/test_accuracy.py::TestLongAlpaca7B::test_multiblock_aggressive # 6 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_lookahead # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_medusa[] # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_medusa[cuda_graph] # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_eagle[] # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph] # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph-chunked_context] # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph-typical_acceptance] # 5 mins
- accuracy/test_accuracy.py::TestLlama2_7B::test_auto_dtype
- accuracy/test_cli_flow.py::TestGpt2::test_auto_dtype # 1 min
- accuracy/test_cli_flow.py::TestGpt2::test_beam_search # 1 min
- accuracy/test_cli_flow.py::TestGpt2::test_beam_search_large # 6 mins
- accuracy/test_cli_flow.py::TestGpt2::test_weight_streaming_ootb # 1.5 mins
- accuracy/test_cli_flow.py::TestGpt2::test_cuda_graph # 1 min
- accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype # 1.5 mins
- accuracy/test_cli_flow.py::TestMamba130M::test_auto_dtype # 1 min
- accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive # 6 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_lookahead # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[] # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[cuda_graph] # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[] # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph] # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph-chunked_context] # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph-typical_acceptance] # 5 mins
- accuracy/test_cli_flow.py::TestLlama2_7B::test_auto_dtype
- examples/test_chatglm.py::test_llm_glm_4_9b_single_gpu_summary[glm-4-9b-disable_weight_only]
- unittest/trt/attention/test_gpt_attention_IFB.py
- unittest/trt/attention/test_gpt_attention_no_cache.py
@ -111,14 +111,14 @@ l0_a10:
- examples/test_mamba.py::test_llm_mamba_1gpu[mamba-130m-float16-disable_gemm_plugin] # 3 mins
- examples/test_mamba.py::test_llm_mamba_1gpu[mamba2-130m-float16-disable_gemm_plugin]
- examples/test_mamba.py::test_llm_mamba_1gpu[mamba-codestral-7B-v0.1-float16-disable_gemm_plugin] # 4 mins
- accuracy/test_accuracy.py::TestGpt2::test_context_fmha_disabled # 1 min
- accuracy/test_accuracy.py::TestLlama7B::test_auto_dtype # 2 mins
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_cyclic_kv_cache
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_cyclic_kv_cache_beam_search
- accuracy/test_accuracy.py::TestGpt2::test_attention_ootb
- accuracy/test_accuracy.py::TestStarcoder2_3B::test_auto_dtype
- accuracy/test_accuracy.py::TestPhi2::test_auto_dtype # 2 mins
- accuracy/test_accuracy.py::TestGemma2_9BIt::test_auto_dtype
- accuracy/test_cli_flow.py::TestGpt2::test_context_fmha_disabled # 1 min
- accuracy/test_cli_flow.py::TestLlama7B::test_auto_dtype # 2 mins
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache_beam_search
- accuracy/test_cli_flow.py::TestGpt2::test_attention_ootb
- accuracy/test_cli_flow.py::TestStarcoder2_3B::test_auto_dtype
- accuracy/test_cli_flow.py::TestPhi2::test_auto_dtype # 2 mins
- accuracy/test_cli_flow.py::TestGemma2_9BIt::test_auto_dtype
- examples/test_mamba.py::test_llm_mamba_1gpu[mamba-130m-float16-enable_gemm_plugin] # 2 mins
- test_e2e.py::test_llmapi_load_engine_from_build_command[llama-codellama/CodeLlama-7b-Instruct-hf] # 5min
- test_e2e.py::test_llmapi_load_ckpt_from_convert_command # 5min

@ -16,19 +16,19 @@ l0_a100:
- unittest/llmapi/test_llm.py -m "part0"
- unittest/llmapi/test_llm.py -m "not part0"
- unittest/llmapi/test_executor.py
- accuracy/test_accuracy.py::TestGpt2::test_int8_kv_cache # 1 min
- accuracy/test_accuracy.py::TestGpt2::test_weight_only[int4] # 1 min
- accuracy/test_accuracy.py::TestStarcoder2_15B::test_smooth_quant_ootb
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_smooth_quant_ootb
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_smooth_quant
- accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_weight_only[int8]
- accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_weight_only[int4]
- accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_weight_only_int8_kv_cache[int8]
- accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_weight_only_manage_weights[int4]
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_smooth_quant_ootb_manage_weights
- accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_int8_gptq
- accuracy/test_accuracy.py::TestQwen2_7BInstruct::test_int4_awq_pre_quantized
- accuracy/test_cli_flow.py::TestGpt2::test_int8_kv_cache # 1 min
- accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int4] # 1 min
- accuracy/test_cli_flow.py::TestStarcoder2_15B::test_smooth_quant_ootb
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_smooth_quant_ootb
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_smooth_quant
- accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only[int8]
- accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only[int4]
- accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only_int8_kv_cache[int8]
- accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_weight_only_manage_weights[int4]
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_smooth_quant_ootb_manage_weights
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_int8_gptq
- accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_int4_awq_prequantized
- condition:
ranges:
system_gpu_count:
@ -41,9 +41,9 @@ l0_a100:
terms:
stage: post_merge
tests:
- accuracy/test_accuracy.py::TestGptNext::test_auto_dtype # 1.5 mins
- accuracy/test_accuracy.py::TestSantacoder::test_auto_dtype # 1.5 mins
- accuracy/test_accuracy.py::TestGpt2::test_weight_only[int8] # 1 min
- accuracy/test_cli_flow.py::TestGptNext::test_auto_dtype # 1.5 mins
- accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype # 1.5 mins
- accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int8] # 1 min
- unittest/trt/model_api/test_model_level_api.py
- unittest/trt/model_api/test_model_quantization.py
- unittest/trt/model_api/test_model_api_multi_gpu.py

@ -36,16 +36,16 @@ l0_a30:
- examples/test_multimodal.py::test_llm_multimodal_general[Qwen2-VL-7B-Instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:4]
- examples/test_multimodal.py::test_llm_multimodal_general[llava-v1.6-mistral-7b-hf-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
- examples/test_multimodal.py::test_llm_multimodal_general[llava-v1.6-mistral-7b-hf-vision-trtllm-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
- accuracy/test_accuracy.py::TestMinitron4BBase::test_auto_dtype
- accuracy/test_accuracy.py::TestLlama2_7B::test_weight_sparsity # 4 mins
- accuracy/test_accuracy.py::TestLlama3_1_8BInstruct::test_medusa_fp8_pre_quantized
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_weight_streaming[0.1]
- accuracy/test_accuracy.py::TestGemma2_9BIt::test_auto_dtype
- accuracy/test_accuracy.py::TestPhi3Mini4kInstruct::test_auto_dtype
- accuracy/test_accuracy.py::TestPhi3Mini128kInstruct::test_auto_dtype
- accuracy/test_accuracy.py::TestPhi3_5MiniInstruct::test_auto_dtype
- accuracy/test_accuracy.py::TestQwen2_0_5BInstruct::test_weight_only
- accuracy/test_accuracy.py::TestQwen2_5_1_5BInstruct::test_weight_only
- accuracy/test_cli_flow.py::TestMinitron4BBase::test_auto_dtype
- accuracy/test_cli_flow.py::TestLlama2_7B::test_weight_sparsity # 4 mins
- accuracy/test_cli_flow.py::TestLlama3_1_8BInstruct::test_medusa_fp8_prequantized
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_weight_streaming[0.1]
- accuracy/test_cli_flow.py::TestGemma2_9BIt::test_auto_dtype
- accuracy/test_cli_flow.py::TestPhi3Mini4kInstruct::test_auto_dtype
- accuracy/test_cli_flow.py::TestPhi3Mini128kInstruct::test_auto_dtype
- accuracy/test_cli_flow.py::TestPhi3_5MiniInstruct::test_auto_dtype
- accuracy/test_cli_flow.py::TestQwen2_0_5BInstruct::test_weight_only
- accuracy/test_cli_flow.py::TestQwen2_5_1_5BInstruct::test_weight_only
- examples/test_multimodal.py::test_llm_multimodal_general[deplot-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8]
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:2-disable_fp8]
@ -97,8 +97,8 @@ l0_a30:
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-flan-t5-small-float32-disable_gemm_plugin-enable_attention_plugin-disable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8] # 1 mins
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-flan-t5-small-float32-disable_gemm_plugin-disable_attention_plugin-disable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8] # 1 mins
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-bfloat16-enable_gemm_plugin-enable_attention_plugin-disable_paged_kv_cache-tp:1-pp:1-nb:1-disable_fp8] # 1 mins
- accuracy/test_accuracy.py::TestLlama7B::test_manage_weights # 2 mins
- accuracy/test_accuracy.py::TestQwen2_7BInstruct::test_weight_only
- accuracy/test_cli_flow.py::TestLlama7B::test_manage_weights # 2 mins
- accuracy/test_cli_flow.py::TestQwen2_7BInstruct::test_weight_only
- examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf] # 15 mins
- examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-summarization] # 5 mins
- examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-summarization_long] # 6 mins

@ -11,11 +11,13 @@ l0_b200:
linux_distribution_name: ubuntu*
tests:
# ------------- PyTorch tests ---------------
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_nvfp4
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
- test_e2e.py::test_ptq_quickstart_advanced_mtp[DeepSeek-V3-Lite-BF16-DeepSeek-V3-Lite/bf16]
- test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
- examples/test_pytorch.py::test_llm_llama_1gpu[llama-3.1-8b-enable_fp4]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- unittest/_torch -k "not (modeling or multi_gpu or auto_deploy)"
- unittest/_torch -k "modeling_llama"
@ -24,17 +26,15 @@ l0_b200:
- unittest/_torch/multi_gpu_modeling -k "deepseek and tp1 and not nextn0"
- unittest/_torch/auto_deploy/unit/singlegpu
- unittest/_torch/speculative/test_eagle3.py
- examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-enable_fp4]
- examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-disable_fp4]
# ------------- TRT tests ---------------
- accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4
- accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-disable_fused_quant]
- accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-enable_fused_quant]
- accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-disable_fused_quant]
- accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-enable_fused_quant]
- accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_auto_dtype
- accuracy/test_accuracy.py::TestLlama3_8BInstruct::test_fp8
- accuracy/test_accuracy.py::TestMixtral8x7B::test_nvfp4_pre_quantized
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-disable_fused_quant]
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[disable_norm_quant_fusion-enable_fused_quant]
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-disable_fused_quant]
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_nvfp4_gemm_plugin[enable_norm_quant_fusion-enable_fused_quant]
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_auto_dtype
- accuracy/test_cli_flow.py::TestLlama3_8BInstruct::test_fp8
- accuracy/test_cli_flow.py::TestMixtral8x7B::test_nvfp4_prequantized
- unittest/trt/attention/test_gpt_attention.py -k "trtllm_gen"
- unittest/llmapi/test_llm_quant.py # 3.5 mins on B200
- unittest/trt/functional/test_fp4_gemm.py # 3 mins on B200

@ -32,16 +32,16 @@ l0_dgx_h100:
- test_cpp.py::test_multi_gpu_trt_gpt_real_decoder[90]
- test_cpp.py::test_multi_gpu_disagg[90]
# ------------- TRT tests ---------------
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_fp8_tp2[disable_reduce_fusion-disable_fp8_context_fmha]
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_fp8_tp2[enable_reduce_fusion-enable_fp8_context_fmha]
- accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_pp4
- accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_2gpus[cp2]
- accuracy/test_accuracy.py::TestLlama2_7B::test_tp2cp2
- accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_2gpus[pp2] # 2 mins
- accuracy/test_accuracy.py::TestLlama3_1_8B::test_fp8_rowwise_tp4[enable_gemm_allreduce_plugin]
- accuracy/test_accuracy.py::TestMixtral8x7B::test_fp8_tp2pp2
- accuracy/test_accuracy.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights
- accuracy/test_accuracy.py::TestQwen2_57B_A14B::test_tp4
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_fp8_tp2[disable_reduce_fusion-disable_fp8_context_fmha]
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_fp8_tp2[enable_reduce_fusion-enable_fp8_context_fmha]
- accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_pp4
- accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_2gpus[cp2]
- accuracy/test_cli_flow.py::TestLlama2_7B::test_tp2cp2
- accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_2gpus[pp2] # 2 mins
- accuracy/test_cli_flow.py::TestLlama3_1_8B::test_fp8_rowwise_tp4[enable_gemm_allreduce_plugin]
- accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2
- accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights
- accuracy/test_cli_flow.py::TestQwen2_57B_A14B::test_tp4
- examples/test_llama.py::test_llm_llama_long_alpaca_8gpu_summary[pg64317-tp4pp2-nb:4]
- examples/test_llama.py::test_llm_llama_v2_lora_benchmark_2gpu[chinese_lora-llama-v2-13b-hf]
- examples/test_mixtral.py::test_llm_mixtral_moe_plugin_lora_4gpus[Mixtral-8x7B-v0.1-chinese-mixtral-lora]
@ -80,16 +80,16 @@ l0_dgx_h100:
- examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-disable_attention_plugin-disable_context_fmha-tp:2-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion]
- examples/test_recurrentgemma.py::test_llm_recurrentgemma_2gpu[recurrentgemma-2b]
- examples/test_mamba.py::test_llm_mamba2_2gpu[mamba-codestral-7B-v0.1]
- accuracy/test_accuracy.py::TestLlama2_7B::test_smooth_quant_ootb_tp2
- accuracy/test_accuracy.py::TestLlama2_7B::test_int4_awq_tp2
- accuracy/test_accuracy.py::TestLlama2_7B::test_int4_awq_pre_quantized_tp2
- accuracy/test_accuracy.py::TestLlama2_7B::test_int4_gptq_pre_quantized_tp2
- accuracy/test_accuracy.py::TestLlama3_1_8B::test_tp4[disable_gemm_allreduce_plugin]
- accuracy/test_accuracy.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin]
- accuracy/test_accuracy.py::TestLlama3_1_8B::test_fp8_rowwise_tp4[disable_gemm_allreduce_plugin]
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_fp8_tp2[disable_reduce_fusion-enable_fp8_context_fmha]
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_fp8_tp2[enable_reduce_fusion-disable_fp8_context_fmha]
- accuracy/test_accuracy.py::TestPhi2::test_tp2
- accuracy/test_cli_flow.py::TestLlama2_7B::test_smooth_quant_ootb_tp2
- accuracy/test_cli_flow.py::TestLlama2_7B::test_int4_awq_tp2
- accuracy/test_cli_flow.py::TestLlama2_7B::test_int4_awq_prequantized_tp2
- accuracy/test_cli_flow.py::TestLlama2_7B::test_int4_gptq_prequantized_tp2
- accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[disable_gemm_allreduce_plugin]
- accuracy/test_cli_flow.py::TestLlama3_1_8B::test_tp4[enable_gemm_allreduce_plugin]
- accuracy/test_cli_flow.py::TestLlama3_1_8B::test_fp8_rowwise_tp4[disable_gemm_allreduce_plugin]
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_fp8_tp2[disable_reduce_fusion-enable_fp8_context_fmha]
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_fp8_tp2[enable_reduce_fusion-disable_fp8_context_fmha]
- accuracy/test_cli_flow.py::TestPhi2::test_tp2
- test_e2e.py::test_llmapi_quant_llama_70b
- test_e2e.py::test_llmapi_example_distributed_autopp_tp2
- examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float16-enable_gemm_plugin-enable_attention_plugin-disable_paged_kv_cache-tp:2-pp:2-nb:1-disable_fp8]

@ -33,8 +33,8 @@ l0_gh200:
stage: post_merge
tests:
- unittest/test_model_runner_cpp.py
- accuracy/test_accuracy.py::TestGptNext::test_auto_dtype # 1.5 mins
- accuracy/test_accuracy.py::TestSantacoder::test_auto_dtype # 1.5 mins
- accuracy/test_cli_flow.py::TestGptNext::test_auto_dtype # 1.5 mins
- accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype # 1.5 mins
- examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_py_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1]
- examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1]
- unittest/trt/model/eagle

@ -18,10 +18,10 @@ l0_h100:
- unittest/_torch/modeling -k "modeling_mixtral"
- unittest/_torch/modeling -k "modeling_nemotron"
- unittest/_torch/multi_gpu_modeling -k "deepseek and tp1 and nextn0"
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
- examples/test_pytorch.py::test_llm_llama_1gpu[llama-3.1-8b-disable_fp4]
- examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-enable_fp8-disable_fp4]
# ------------- CPP tests ---------------
- test_cpp.py::test_unit_tests[90]
- test_cpp.py::test_model[fp8-llama-90]
@ -46,9 +46,9 @@ l0_h100:
- test_e2e.py::test_trtllm_bench_iteration_log[TRT-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
- test_e2e.py::test_trtllm_bench_iteration_log[TRT-non-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
- test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
- accuracy/test_accuracy.py::TestLongAlpaca7B::test_multiblock_aggressive # 6 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_medusa[] # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_medusa[cuda_graph] # 5 mins
- accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive # 6 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[] # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[cuda_graph] # 5 mins
- examples/test_llama.py::test_llama_3_x_fp8_with_bf16_lora[llama-3.1-8b]
- examples/test_llama.py::test_llama_3_x_fp8_with_bf16_lora[llama-3.2-1b]
- examples/test_qwen.py::test_llm_hf_qwen_multi_lora_1gpu[qwen2.5_1.5b_instruct]
@ -68,11 +68,11 @@ l0_h100:
- unittest/trt/model/test_mamba.py # 3 mins
- examples/test_redrafter.py::test_llm_redrafter_1gpu[use_cpp_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb8-bs8]
- examples/test_medusa.py::test_llm_medusa_1gpu[use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs8]
- accuracy/test_accuracy.py::TestLlama3_1_8BInstruct::test_fp8_pre_quantized
- accuracy/test_accuracy.py::TestLlama2_7B::test_fp8
- accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_gemm_plugin
- accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_gemm_swiglu_plugin
- accuracy/test_accuracy.py::TestLlama2_7B::test_fp8_low_latency_gemm_plugin
- accuracy/test_cli_flow.py::TestLlama3_1_8BInstruct::test_fp8_prequantized
- accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8
- accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_gemm_plugin
- accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_gemm_swiglu_plugin
- accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8_low_latency_gemm_plugin
- examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1]
- examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1]
- examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-cnn_dailymail-Qwen2-VL-7B-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False]
@ -119,35 +119,35 @@ l0_h100:
- examples/test_granite.py::test_granite_bf16_lora[granite-3.0-2b-instruct]
- examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct]
- examples/test_llama.py::test_llama_3_x_fp8_with_bf16_lora[llama-v3-8b-instruct-hf]
- accuracy/test_accuracy.py::TestGpt2::test_auto_dtype # 1 min
- accuracy/test_accuracy.py::TestGpt2::test_beam_search # 1 min
- accuracy/test_accuracy.py::TestGpt2::test_weight_streaming_ootb # 1.5 mins
- accuracy/test_accuracy.py::TestGpt2::test_cuda_graph # 1 min
- accuracy/test_accuracy.py::TestGpt2::test_context_fmha_disabled # 1 min
- accuracy/test_accuracy.py::TestGptNext::test_auto_dtype # 1.5 mins
- accuracy/test_accuracy.py::TestSantacoder::test_auto_dtype # 1.5 mins
- accuracy/test_accuracy.py::TestMamba130M::test_auto_dtype # 1 min
- accuracy/test_accuracy.py::TestVicuna7B::test_lookahead # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_eagle[] # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph] # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph-chunked_context] # 5 mins
- accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph-typical_acceptance] # 5 mins
- accuracy/test_accuracy.py::TestPhi2::test_auto_dtype # 2 mins
- accuracy/test_accuracy.py::TestGpt2Medium::test_fp8
- accuracy/test_accuracy.py::TestGpt2Medium::test_fp8_lm_head
- accuracy/test_accuracy.py::TestGpt2::test_attention_ootb
- accuracy/test_accuracy.py::TestStarcoder2_3B::test_auto_dtype
- accuracy/test_accuracy.py::TestMinitron4BBase::test_fp8
- accuracy/test_accuracy.py::TestNemotronMini4BInstruct::test_fp8_pre_quantized
- accuracy/test_accuracy.py::TestTinyLlama1_1BChat::test_float32
- accuracy/test_accuracy.py::TestLlama3_8BInstructGradient1048k::test_long_context
- accuracy/test_accuracy.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl
- accuracy/test_accuracy.py::TestLlama3_1_8B::test_fp8
- accuracy/test_accuracy.py::TestLlama3_1_8B::test_autoq
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_auto_dtype
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_cyclic_kv_cache
- accuracy/test_accuracy.py::TestLlama3_2_1B::test_cyclic_kv_cache_beam_search
- accuracy/test_accuracy.py::TestGemma2_9BIt::test_auto_dtype
- accuracy/test_cli_flow.py::TestGpt2::test_auto_dtype # 1 min
- accuracy/test_cli_flow.py::TestGpt2::test_beam_search # 1 min
- accuracy/test_cli_flow.py::TestGpt2::test_weight_streaming_ootb # 1.5 mins
- accuracy/test_cli_flow.py::TestGpt2::test_cuda_graph # 1 min
- accuracy/test_cli_flow.py::TestGpt2::test_context_fmha_disabled # 1 min
- accuracy/test_cli_flow.py::TestGptNext::test_auto_dtype # 1.5 mins
- accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype # 1.5 mins
- accuracy/test_cli_flow.py::TestMamba130M::test_auto_dtype # 1 min
- accuracy/test_cli_flow.py::TestVicuna7B::test_lookahead # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[] # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph] # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph-chunked_context] # 5 mins
- accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph-typical_acceptance] # 5 mins
- accuracy/test_cli_flow.py::TestPhi2::test_auto_dtype # 2 mins
- accuracy/test_cli_flow.py::TestGpt2Medium::test_fp8
- accuracy/test_cli_flow.py::TestGpt2Medium::test_fp8_lm_head
- accuracy/test_cli_flow.py::TestGpt2::test_attention_ootb
- accuracy/test_cli_flow.py::TestStarcoder2_3B::test_auto_dtype
- accuracy/test_cli_flow.py::TestMinitron4BBase::test_fp8
- accuracy/test_cli_flow.py::TestNemotronMini4BInstruct::test_fp8_prequantized
- accuracy/test_cli_flow.py::TestTinyLlama1_1BChat::test_float32
- accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context
- accuracy/test_cli_flow.py::TestLlama3_8BInstructGradient1048k::test_long_context_ppl
- accuracy/test_cli_flow.py::TestLlama3_1_8B::test_fp8
- accuracy/test_cli_flow.py::TestLlama3_1_8B::test_autoq
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_auto_dtype
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache
- accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache_beam_search
- accuracy/test_cli_flow.py::TestGemma2_9BIt::test_auto_dtype
- examples/test_gpt.py::test_llm_minitron_fp8_with_pseudo_loras[4b]
- examples/test_chatglm.py::test_llm_glm_4_9b_single_gpu_summary[glm-4-9b-disable_weight_only]
- unittest/trt/model_api/test_model_quantization.py # 20 mins on H100

@ -33,11 +33,11 @@ l0_l40s:
- unittest/trt/attention/test_gpt_attention.py -k "xqa_generic"
- unittest/trt/quantization # 18 mins
- unittest/trt/functional # 37 mins
- accuracy/test_accuracy.py::TestLlama2_7B::test_fp8
- accuracy/test_accuracy.py::TestLlama3_1_8B::test_fp8_rowwise
- accuracy/test_accuracy.py::TestLlama3_1_8B::test_fp8_rowwise_meta_recipe
- accuracy/test_accuracy.py::TestGemma2_9BIt::test_auto_dtype
- accuracy/test_accuracy.py::TestPhi3Small128kInstruct::test_auto_dtype
- accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8
- accuracy/test_cli_flow.py::TestLlama3_1_8B::test_fp8_rowwise_meta_recipe
- accuracy/test_cli_flow.py::TestGemma2_9BIt::test_auto_dtype
- accuracy/test_cli_flow.py::TestPhi3Small128kInstruct::test_auto_dtype
- accuracy/test_llm_api.py::TestLlama3_1_8B::test_fp8_rowwise
- examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_fp16]
- examples/test_llama.py::test_llm_llama_v3_dora_1gpu[commonsense-llama-v3-8b-dora-r32-llama-v3-8b-hf-base_fp16]
- examples/test_llama.py::test_llm_llama_v1_1gpu_kv_cache_reuse_with_prompt_table[llama-7b]
@ -68,11 +68,11 @@ l0_l40s:
terms:
stage: post_merge
tests:
- accuracy/test_accuracy.py::TestGpt2::test_attention_ootb
- accuracy/test_accuracy.py::TestStarcoder2_3B::test_auto_dtype
- accuracy/test_accuracy.py::TestPhi2::test_auto_dtype
- accuracy/test_accuracy.py::TestPhi3Small8kInstruct::test_auto_dtype
- accuracy/test_accuracy.py::TestQwen1_5MoeA2_7BChat::test_weight_only
- accuracy/test_cli_flow.py::TestGpt2::test_attention_ootb
- accuracy/test_cli_flow.py::TestStarcoder2_3B::test_auto_dtype
- accuracy/test_cli_flow.py::TestPhi2::test_auto_dtype
- accuracy/test_cli_flow.py::TestPhi3Small8kInstruct::test_auto_dtype
- accuracy/test_cli_flow.py::TestQwen1_5MoeA2_7BChat::test_weight_only
- examples/test_gpt.py::test_llm_gpt2_next_prompt_tuning[use_cpp_session-tp1] # 10 mins
- examples/test_gpt.py::test_llm_gpt2_next_prompt_tuning[use_py_session-tp1]
- examples/test_multimodal.py::test_llm_multimodal_general[neva-22b-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1]

@ -87,10 +87,10 @@ full:B200_PCIe/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_cpp_sessi
full:B200_PCIe/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_cpp_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb8-bs8] SKIP (Disable for Blackwell spec decoding)
full:B200_PCIe/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_py_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb5-bs8] SKIP (Disable for Blackwell spec decoding)
full:B200_PCIe/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_py_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb8-bs8] SKIP (Disable for Blackwell spec decoding)
full:B200_PCIe/accuracy/test_accuracy.py::TestGpt2::test_weight_only[int8] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_accuracy.py::TestGpt2::test_weight_only[int4] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_accuracy.py::TestGpt2::test_smooth_quant[] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_accuracy.py::TestGpt2::test_smooth_quant[per_token-per_channel] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int8] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int4] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_cli_flow.py::TestGpt2::test_smooth_quant[] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_cli_flow.py::TestGpt2::test_smooth_quant[per_token-per_channel] SKIP (Disable for Blackwell)
full:B200_PCIe/examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1] SKIP (Disable for Blackwell)
full:B200_PCIe/examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-bfloat16-enable_gemm_plugin-enable_attention_plugin-disable_paged_kv_cache-tp:1-pp:1-nb:1] SKIP (Disable for Blackwell)
full:B200_PCIe/examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-float16-disable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1] SKIP (Disable for Blackwell)
@ -110,9 +110,9 @@ full:B200_PCIe/test_e2e.py::test_benchmark_sanity[bert_base] SKIP (Disable for B
full:B200_PCIe/test_e2e.py::test_benchmark_sanity[roberta_base] SKIP (Disable for Blackwell)
full:B200_PCIe/unittest/trt/functional SKIP (Disable for Blackwell)
full:B200_PCIe/unittest/trt/quantization SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_accuracy.py::TestVicuna7B::test_medusa[] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_accuracy.py::TestVicuna7B::test_medusa[cuda_graph] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_accuracy.py::TestVicuna7B::test_lookahead SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[cuda_graph] SKIP (Disable for Blackwell)
full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_lookahead SKIP (Disable for Blackwell)
full:B200_PCIe/unittest/trt/attention/test_bert_attention.py SKIP (Disable for Blackwell)
full:B200_PCIe/unittest/trt/model/test_mamba.py SKIP (Disable for Blackwell)
full:B200_PCIe/examples/test_phi.py::test_llm_phi_single_gpu_summary[Phi-3-mini-128k-instruct-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_fmha_with_fp32_acc-nb:1] SKIP (Disable for Blackwell)
@ -154,9 +154,9 @@ full:B200_PCIe/examples/test_qwen.py::test_llm_qwen_7b_single_gpu_summary[Qwen2.
full:B200_PCIe/examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask)
full:B200_PCIe/examples/test_multimodal.py::test_llm_multimodal_general[deplot-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask)
full:B200_PCIe/examples/test_eagle.py::test_llm_eagle_1gpu[EAGLE-Vicuna-7B-v1.3-float16-bs1] SKIP (Disable for Blackwell for Speculative Dec)
full:B200_PCIe/accuracy/test_accuracy.py::TestVicuna7B::test_eagle[] SKIP (Disable for Blackwell for Speculative Dec)
full:B200_PCIe/accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph] SKIP (Disable for Blackwell for Speculative Dec)
full:B200_PCIe/accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph-chunked_context] SKIP (Disable for Blackwell for Speculative Dec)
full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[] SKIP (Disable for Blackwell for Speculative Dec)
full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph] SKIP (Disable for Blackwell for Speculative Dec)
full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph-chunked_context] SKIP (Disable for Blackwell for Speculative Dec)
full:B200_PCIe/examples/test_medusa.py::test_llm_medusa_1gpu[use_py_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs8] SKIP (Disable for Blackwell for Speculative Dec)
full:B200_PCIe/unittest/llmapi/test_llm_models.py -m "part0" SKIP (Disable for Blackwell for context fmha doesn't support when headsize is 80/96)
full:B200_PCIe/examples/test_multimodal.py::test_llm_multimodal_general[Phi-3-vision-128k-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support when headsize is 96)
@ -194,10 +194,10 @@ full:B200/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_cpp_session-re
full:B200/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_cpp_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb8-bs8] SKIP (Disable for Blackwell spec decoding)
full:B200/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_py_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb5-bs8] SKIP (Disable for Blackwell spec decoding)
full:B200/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_py_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb8-bs8] SKIP (Disable for Blackwell spec decoding)
full:B200/accuracy/test_accuracy.py::TestGpt2::test_weight_only[int8] SKIP (Disable for Blackwell)
full:B200/accuracy/test_accuracy.py::TestGpt2::test_weight_only[int4] SKIP (Disable for Blackwell)
full:B200/accuracy/test_accuracy.py::TestGpt2::test_smooth_quant[] SKIP (Disable for Blackwell)
full:B200/accuracy/test_accuracy.py::TestGpt2::test_smooth_quant[per_token-per_channel] SKIP (Disable for Blackwell)
full:B200/accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int8] SKIP (Disable for Blackwell)
full:B200/accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int4] SKIP (Disable for Blackwell)
full:B200/accuracy/test_cli_flow.py::TestGpt2::test_smooth_quant[] SKIP (Disable for Blackwell)
full:B200/accuracy/test_cli_flow.py::TestGpt2::test_smooth_quant[per_token-per_channel] SKIP (Disable for Blackwell)
full:B200/examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1] SKIP (Disable for Blackwell)
full:B200/examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-bfloat16-enable_gemm_plugin-enable_attention_plugin-disable_paged_kv_cache-tp:1-pp:1-nb:1] SKIP (Disable for Blackwell)
full:B200/examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-bart-large-cnn-float16-disable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:1-pp:1-nb:1] SKIP (Disable for Blackwell)
@ -220,9 +220,9 @@ full:B200/test_e2e.py::test_benchmark_sanity[bert_base] SKIP (Disable for Blackw
full:B200/test_e2e.py::test_benchmark_sanity[roberta_base] SKIP (Disable for Blackwell)
full:B200/unittest/trt/functional SKIP (Disable for Blackwell)
full:B200/unittest/trt/quantization SKIP (Disable for Blackwell)
full:B200/accuracy/test_accuracy.py::TestVicuna7B::test_medusa[] SKIP (Disable for Blackwell)
full:B200/accuracy/test_accuracy.py::TestVicuna7B::test_medusa[cuda_graph] SKIP (Disable for Blackwell)
full:B200/accuracy/test_accuracy.py::TestVicuna7B::test_lookahead SKIP (Disable for Blackwell)
full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[] SKIP (Disable for Blackwell)
full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[cuda_graph] SKIP (Disable for Blackwell)
full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_lookahead SKIP (Disable for Blackwell)
full:B200/unittest/trt/attention/test_bert_attention.py SKIP (Disable for Blackwell)
full:B200/unittest/trt/model/test_mamba.py SKIP (Disable for Blackwell)
full:B200/examples/test_phi.py::test_llm_phi_single_gpu_summary[Phi-3-mini-128k-instruct-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_fmha_with_fp32_acc-nb:1] SKIP (Disable for Blackwell)
@ -264,9 +264,9 @@ full:B200/examples/test_qwen.py::test_llm_qwen_7b_single_gpu_summary[Qwen2.5-1.5
full:B200/examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask)
full:B200/examples/test_multimodal.py::test_llm_multimodal_general[deplot-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask)
full:B200/examples/test_eagle.py::test_llm_eagle_1gpu[EAGLE-Vicuna-7B-v1.3-float16-bs1] SKIP (Disable for Blackwell for Speculative Dec)
full:B200/accuracy/test_accuracy.py::TestVicuna7B::test_eagle[] SKIP (Disable for Blackwell for Speculative Dec)
full:B200/accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph] SKIP (Disable for Blackwell for Speculative Dec)
full:B200/accuracy/test_accuracy.py::TestVicuna7B::test_eagle[cuda_graph-chunked_context] SKIP (Disable for Blackwell for Speculative Dec)
full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[] SKIP (Disable for Blackwell for Speculative Dec)
full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph] SKIP (Disable for Blackwell for Speculative Dec)
full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph-chunked_context] SKIP (Disable for Blackwell for Speculative Dec)
full:B200/examples/test_medusa.py::test_llm_medusa_1gpu[use_py_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs8] SKIP (Disable for Blackwell for Speculative Dec)
full:B200/unittest/llmapi/test_llm_models.py -m "part0" SKIP (Disable for Blackwell for context fmha doesn't support when headsize is 80/96)
full:B200/examples/test_multimodal.py::test_llm_multimodal_general[Phi-3-vision-128k-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support when headsize is 96)
@ -339,13 +339,13 @@ full:B200/examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoderplu
full:B200/examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.1-8b] SKIP (No available XQA kernels are found for speculative decoding mode)
full:B200/examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.2-1b] SKIP (No available XQA kernels are found for speculative decoding mode)
full:B200/examples/test_medusa.py::test_llm_medusa_1gpu[use_py_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs1] SKIP (No available XQA kernels are found for speculative decoding mode)
full:B200/examples/accuracy/test_accuracy.py::TestLlama3_1_8BInstruct::test_medusa_fp8_pre_quantized SKIP (No available XQA kernels are found for speculative decoding mode)
full:B200/examples/accuracy/test_cli_flow.py::TestLlama3_1_8BInstruct::test_medusa_fp8_prequantized SKIP (No available XQA kernels are found for speculative decoding mode)
full:B200/examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:2-bfloat16-bs:1-nb:1] SKIP (Only Context FMHA supports custom mask input currently)
full:B200/test_e2e.py::test_llmapi_load_engine_from_build_command[falcon-falcon-7b-instruct] SKIP (Not supported on B200)
full:B200/examples/test_qwen.py::test_llm_qwen_smooth_quant_single_gpu_summary[qwen_7b_chat-enable_ptpc-nb:4] SKIP (Not supported on B200)
full:B200/examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int4-float16] SKIP (not support on B200)
full:B200/examples/test_mixtral.py::test_llm_mixtral_moe_plugin_fp8_lora_4gpus[Mixtral-8x7B-v0.1-chinese-mixtral-lora] SKIP (https://nvbugs/5064768)
full:B200/accuracy/test_accuracy.py::TestGpt2::test_int8_kv_cache SKIP (not support on B200)
full:B200/accuracy/test_cli_flow.py::TestGpt2::test_int8_kv_cache SKIP (not support on B200)
full:B200/examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:8-nb:1] SKIP (Only Context FMHA supports custom mask input currently)
full:B200/examples/test_multimodal.py::test_llm_multimodal_general[deplot-pp:1-tp:1-float16-bs:8-nb:1] SKIP (Only Context FMHA supports custom mask input currently)
full:B200/examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int8_sq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (not support on B200)
@ -391,7 +391,6 @@ full:B200/test_e2e.py::test_ptp_quickstart_advanced[Nemotron4_4B-BF16-nemotron/M
full:B200/test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B] SKIP (https://nvbugs/5136994)
full:B200/test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False] SKIP (https://nvbugs/5136994)
examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5141288)
examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-disable_fp4] SKIP (https://nvbugs/5141289)
examples/test_qwen.py::test_llm_qwen_7b_multi_gpus_summary[qwen2_vl_7b_instruct-enable_fmha_fp32_acc-enable_plugin-tp2pp2-nb:4] SKIP (https://nvbugs/5141290)
examples/test_qwen.py::test_llm_qwen_single_gpu_summary[qwen2_vl_7b_instruct-enable_paged_kv_cache-enable_remove_input_padding-disable_weight_only-disable_fmha] SKIP (https://nvbugs/5141290)
examples/test_qwen.py::test_llm_qwen_single_gpu_summary[qwen2_vl_7b_instruct-enable_paged_kv_cache-enable_remove_input_padding-enable_weight_only-enable_fmha_fp32_acc] SKIP (https://nvbugs/5141290)
@ -408,12 +407,12 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_att
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5155144)
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5155144)

full:L40S/accuracy/test_accuracy.py::TestGemma2_9BIt::test_auto_dtype SKIP (https://nvbugs/5176851)
full:L40S/accuracy/test_accuracy.py::TestGemma2_9BIt::test_weight_only[int8] SKIP (https://nvbugs/5176851)
full:L40S/accuracy/test_accuracy.py::TestGemma2_9BIt::test_weight_only[int4] SKIP (https://nvbugs/5176851)
full:L40S/accuracy/test_accuracy.py::TestLlama2_7B::test_fp8 SKIP (https://nvbugs/5176867)
full:L40S/accuracy/test_accuracy.py::TestMixtral8x7B::test_fp8_tp2pp2 SKIP (https://nvbugs/5176867)
full:L40S/accuracy/test_accuracy.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights SKIP (https://nvbugs/5176867)
full:L40S/accuracy/test_cli_flow.py::TestGemma2_9BIt::test_auto_dtype SKIP (https://nvbugs/5176851)
full:L40S/accuracy/test_cli_flow.py::TestGemma2_9BIt::test_weight_only[int8] SKIP (https://nvbugs/5176851)
full:L40S/accuracy/test_cli_flow.py::TestGemma2_9BIt::test_weight_only[int4] SKIP (https://nvbugs/5176851)
full:L40S/accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8 SKIP (https://nvbugs/5176867)
full:L40S/accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2 SKIP (https://nvbugs/5176867)
full:L40S/accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights SKIP (https://nvbugs/5176867)

full:B200/perf/test_perf.py::test_perf[quant:w4a8_awq] SKIP (https://nvbugspro.nvidia.com/bug/5161074)
full:B200/perf/test_perf.py::test_perf[quant:int8_sq_per_tensor] SKIP (https://nvbugspro.nvidia.com/bug/5161074)