[None][infra] Add LongBenchV1 to trtllm-eval. (#10265)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Parent: 6732c76414
Commit: 1f0365da36
@@ -28,11 +28,12 @@ jieba==0.42.1
 rouge==1.0.1
 pytest-rerunfailures
 ruff==0.9.4
-lm_eval[api]==0.4.8
+lm_eval[api]==0.4.9.2
 docstring_parser
 genai-perf==0.0.13
 opentelemetry-sdk>=1.26.0
 opentelemetry-api>=1.26.0
 opentelemetry-exporter-otlp>=1.26.0
 opentelemetry-semantic-conventions-ai>=0.4.1
+fuzzywuzzy==0.18.0
 aiperf==0.3.0
@@ -21,7 +21,8 @@ import tensorrt_llm.profiler as profiler
 from .. import LLM as PyTorchLLM
 from .._tensorrt_engine import LLM
 from ..evaluate import (GSM8K, MMLU, MMMU, CnnDailymail, GPQADiamond,
-                        GPQAExtended, GPQAMain, JsonModeEval, LongBenchV2)
+                        GPQAExtended, GPQAMain, JsonModeEval, LongBenchV1,
+                        LongBenchV2)
 from ..llmapi import BuildConfig, KvCacheConfig
 from ..llmapi.llm_utils import update_llm_args_with_extra_options
 from ..logger import logger, severity_map
@@ -184,6 +185,7 @@ main.add_command(GPQAMain.command)
 main.add_command(GPQAExtended.command)
 main.add_command(JsonModeEval.command)
 main.add_command(MMMU.command)
+main.add_command(LongBenchV1.command)
 main.add_command(LongBenchV2.command)
 
 if __name__ == "__main__":
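With the command registered above, the new task becomes a trtllm-eval subcommand. A hypothetical invocation, assuming the usual `trtllm-eval --model <path>` entry point and using only the options defined later in this diff (the model path is a placeholder), would look roughly like:

    trtllm-eval --model /path/to/Qwen3-30B-A3B-Instruct-2507 longbench_v1 --num_samples 100

Here `--num_samples` limits the run for a quick smoke test; omitting it evaluates the full dataset.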
@@ -15,11 +15,12 @@
 
 from .cnn_dailymail import CnnDailymail
 from .json_mode_eval import JsonModeEval
-from .lm_eval import GSM8K, MMMU, GPQADiamond, GPQAExtended, GPQAMain
+from .lm_eval import (GSM8K, MMMU, GPQADiamond, GPQAExtended, GPQAMain,
+                      LongBenchV1)
 from .longbench_v2 import LongBenchV2
 from .mmlu import MMLU
 
 __all__ = [
     "CnnDailymail", "MMLU", "GSM8K", "GPQADiamond", "GPQAMain", "GPQAExtended",
-    "JsonModeEval", "MMMU", "LongBenchV2"
+    "JsonModeEval", "MMMU", "LongBenchV1", "LongBenchV2"
 ]
@@ -100,10 +100,23 @@ class LmEvalWrapper(TemplateLM):
             "max_gen_toks": "max_tokens",
             "until": "stop",
         }
+        # IMPORTANT:
+        # lm-evaluation-harness controls generation primarily via per-task gen_kwargs.
+        # For example, the `local-completions` model wrapper uses:
+        #   max_tokens <- gen_kwargs["max_tokens"] or gen_kwargs["max_gen_toks"] or _max_gen_toks
+        #   temperature <- gen_kwargs.get("temperature", 0)
+        #   stop <- gen_kwargs.get("until", ...)
+        # See: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/openai_completions.py
+
         if self.sampling_params is None:
-            sampling_params = SamplingParams()
+            sampling_params = SamplingParams(
+                max_tokens=gen_kwargs.get("max_gen_toks", 256),
+                temperature=gen_kwargs.get("temperature", 0),
+                stop=gen_kwargs.get("until", None),
+            )
         else:
            sampling_params = copy.deepcopy(self.sampling_params)
+
         for lm_eval_key, trtllm_key in params_mapping.items():
             value = gen_kwargs.pop(lm_eval_key, None)
             if value is not None:
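A minimal sketch of the new default path, assuming a gen_kwargs dict like the one lm-eval passes per task; FakeSamplingParams and resolve_sampling_params are illustrative stand-ins, not tensorrt_llm APIs:

    from dataclasses import dataclass
    from typing import List, Optional

    # Stand-in for tensorrt_llm.SamplingParams; field names follow the diff above.
    @dataclass
    class FakeSamplingParams:
        max_tokens: int = 256
        temperature: float = 0.0
        stop: Optional[List[str]] = None

    def resolve_sampling_params(gen_kwargs: dict) -> FakeSamplingParams:
        # Mirrors the new default branch: when no SamplingParams is supplied on the
        # trtllm-eval side, the per-task gen_kwargs decide length, temperature, stop.
        return FakeSamplingParams(
            max_tokens=gen_kwargs.get("max_gen_toks", 256),
            temperature=gen_kwargs.get("temperature", 0),
            stop=gen_kwargs.get("until", None),
        )

    # Example gen_kwargs, similar in shape to what a LongBench task config provides.
    print(resolve_sampling_params({"max_gen_toks": 64, "until": ["\n"]}))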
@@ -714,3 +727,156 @@ class MMMU(LmEvalEvaluator):
         kwargs[
             "stop"] = "<|endoftext|>"  # NOTE: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/mmmu/_template_yaml#L10
         MMMU.command_harness(ctx, **kwargs)
+
+
+class LongBenchV1(LmEvalEvaluator):
+    """
+    LongBench v1 evaluation via lm-evaluation-harness.
+
+    Notes:
+    - In lm-eval, `longbench` is typically a *group task* that expands into many
+      subtasks. The base `LmEvalEvaluator.evaluate()` assumes a single task
+      key exists in `results["results"][task_name]`, so we override evaluation
+      to aggregate over subtasks.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__("longbench", **kwargs)
+
+    @staticmethod
+    def _flatten_task_dict(task_dict: dict) -> List[str]:
+        names: List[str] = []
+        for k, v in task_dict.items():
+            if isinstance(v, dict):
+                names.extend(LongBenchV1._flatten_task_dict(v))
+            else:
+                names.append(k)
+        return names
+
+    @staticmethod
+    def _get_group_score(metrics: Dict[str, Any],
+                         *,
+                         preferred_filter: str = "none") -> Optional[float]:
+        """
+        lm-eval stores group metrics as "<metric>,<filter>" (e.g., "score,none").
+        Prefer "score,none" (matches printed table), otherwise accept any
+        "score,<filter>" key.
+        """
+        if not isinstance(metrics, dict):
+            return None
+
+        preferred_key = f"score,{preferred_filter}"
+        v = metrics.get(preferred_key, None)
+        if isinstance(v, (int, float)):
+            return float(v)
+
+        return None
+
+    def evaluate(self,
+                 llm: Union[LLM, PyTorchLLM],
+                 sampling_params: Optional[SamplingParams] = None,
+                 streaming: bool = False) -> float:
+        import lm_eval
+
+        lm_cls = MultimodalLmEvalWrapper if self.MULTIMODAL else LmEvalWrapper
+        results = lm_eval.evaluate(
+            lm=lm_cls(llm,
+                      sampling_params=sampling_params,
+                      streaming=streaming,
+                      chat_template_kwargs=self.chat_template_kwargs),
+            task_dict=self.task_dict,
+            limit=self.num_samples,
+            apply_chat_template=self.apply_chat_template,
+            fewshot_as_multiturn=self.fewshot_as_multiturn,
+            system_instruction=self.system_prompt)
+
+        logger.info(
+            f"lm-eval {self.task_name} results:\n{lm_eval.utils.make_table(results)}"
+        )
+
+        # LongBench is a group task in lm-eval. lm-eval already computes subgroup
+        # "score" values (e.g., `longbench_fewshot`, `longbench_single`, ...).
+        # To keep this implementation simple and aligned with the printed table,
+        # we compute the final LongBench score as the unweighted mean of subgroup
+        # scores.
+        group_results: Dict[str, Dict[str, Any]] = results.get("groups", {})
+        subgroup_names = results.get("group_subtasks",
+                                     {}).get(self.task_name, [])
+        if not subgroup_names:
+            raise KeyError(
+                f"lm-eval did not provide subgroup list for group '{self.task_name}'. "
+                "Expected `results['group_subtasks'][task_name]` to exist.")
+
+        subgroup_scores: List[float] = []
+        missing: List[str] = []
+        for name in subgroup_names:
+            m = group_results.get(name, None)
+            score = self._get_group_score(m)
+            if score is None:
+                missing.append(name)
+            else:
+                subgroup_scores.append(score)
+
+        if not subgroup_scores:
+            raise KeyError(
+                f"lm-eval did not provide subgroup 'score' metrics for '{self.task_name}'. "
+                f"Missing subgroups: {missing[:10]}")
+
+        result_acc = float(np.mean(subgroup_scores)) * 100
+        logger.info(
+            f"lm-eval {self.task_name} average 'score' across {len(subgroup_scores)} subgroups: {result_acc:.2f}"
+        )
+        return result_acc
+
+    @click.command("longbench_v1")
+    @click.option(
+        "--dataset_path",
+        type=str,
+        default=None,
+        help=
+        "The path to LongBench dataset. If unspecified, the dataset is downloaded from HF hub."
+    )
+    @click.option(
+        "--num_samples",
+        type=int,
+        default=None,
+        help="Number of samples to run the evaluation; None means full dataset."
+    )
+    @click.option("--random_seed",
+                  type=int,
+                  default=0,
+                  help="Random seed for dataset processing.")
+    @click.option("--apply_chat_template",
+                  type=click.BOOL,
+                  default=True,
+                  show_default=True,
+                  help="Whether to apply chat template.")
+    @click.option(
+        "--chat_template_kwargs",
+        type=str,
+        default=None,
+        callback=lambda ctx, param, value: json.loads(value) if value else None,
+        help=
+        'Chat template kwargs as JSON string, e.g., \'{"thinking_budget": 0}\'')
+    @click.option("--system_prompt",
+                  type=str,
+                  default=None,
+                  help="System prompt.")
+    @click.pass_context
+    @staticmethod
+    def command(ctx, **kwargs) -> None:
+        llm: Union[LLM, PyTorchLLM] = ctx.obj
+
+        evaluator = LongBenchV1(
+            dataset_path=kwargs.pop("dataset_path", None),
+            num_samples=kwargs.pop("num_samples", None),
+            random_seed=kwargs.pop("random_seed", 0),
+            apply_chat_template=kwargs.pop("apply_chat_template", True),
+            system_prompt=kwargs.pop("system_prompt", None),
+            chat_template_kwargs=kwargs.pop("chat_template_kwargs", None))
+
+        # Let lm-eval task configs control sampling via gen_kwargs.
+        sampling_params = None
+
+        evaluator.evaluate(llm, sampling_params)
+        llm.shutdown()
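For readers unfamiliar with lm-eval's result layout, here is a small, self-contained sketch of the aggregation the override performs. The dict shape ("group_subtasks", "groups", "score,none") follows the code above; the subgroup names reuse the examples from its comments, and the scores are made up for illustration:

    import numpy as np

    # Hypothetical lm_eval.evaluate() output for a group task.
    results = {
        "group_subtasks": {"longbench": ["longbench_single", "longbench_fewshot"]},
        "groups": {
            "longbench_single": {"score,none": 0.42},
            "longbench_fewshot": {"score,none": 0.38},
        },
    }

    subgroups = results["group_subtasks"]["longbench"]
    scores = [results["groups"][name]["score,none"] for name in subgroups]
    # Unweighted mean of subgroup scores, scaled to percent (about 40.0 here).
    print(float(np.mean(scores)) * 100)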
@@ -456,6 +456,34 @@ class LongBenchV2(AccuracyTask):
     )
 
 
+class LongBenchV1(AccuracyTask):
+    DATASET = "longbench_v1"
+    # Keep the dataset local like other accuracy tasks (avoid HF hub traffic).
+    # Expected to be populated in CI image / test environment.
+    DATASET_DIR = f"{llm_models_root()}/datasets/Xnhyacinth/LongBench"
+
+    # NOTE: LongBench v1 is driven by lm-evaluation-harness task configs.
+    # We intentionally do not pin dataset_path here (it can be resolved by lm-eval
+    # via HF Hub or local cache).
+    ALPHA = 0.05
+    BETA = 0.2
+    SIGMA = 50.0
+
+    # Full sample
+    NUM_SAMPLES = 4750
+
+    # These are used by AccuracyTask to construct SamplingParams defaults.
+    # LongBench v1 tasks provide per-task gen_kwargs, so these are mainly a safe fallback.
+    MAX_BATCH_SIZE = 256
+    MAX_INPUT_LEN = 128000
+    MAX_OUTPUT_LEN = 1024
+
+    EVALUATOR_CLS = tensorrt_llm.evaluate.LongBenchV1
+    EVALUATOR_KWARGS = dict(dataset_path=DATASET_DIR,
+                            random_seed=0,
+                            apply_chat_template=True)
+
+
 class CliFlowAccuracyTestHarness:
     # Model
     MODEL_NAME = None
@@ -0,0 +1,8 @@
+Qwen3/Qwen3-30B-A3B-Instruct-2507:
+  # Skip Softmax Attention ref accuracy
+  - extra_acc_spec: "target_sparsity=0.0"
+    accuracy: 47.22
+  - extra_acc_spec: "target_sparsity=0.5"
+    accuracy: 47.22
+  - extra_acc_spec: "target_sparsity=0.9"
+    accuracy: 45.90
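The new reference file keys accuracy values by extra_acc_spec, matching the spec string the test passes to task.evaluate(). A rough sketch of how such an entry could be looked up (this loader is illustrative only; the real accuracy harness has its own resolution logic and assumes PyYAML is available):

    import yaml

    # Inline copy of two entries from the reference file above, for illustration.
    reference_yaml = """
    Qwen3/Qwen3-30B-A3B-Instruct-2507:
      - extra_acc_spec: "target_sparsity=0.0"
        accuracy: 47.22
      - extra_acc_spec: "target_sparsity=0.9"
        accuracy: 45.90
    """

    def lookup_reference(model: str, extra_acc_spec: str) -> float:
        # Find the entry whose extra_acc_spec matches and return its reference accuracy.
        for entry in yaml.safe_load(reference_yaml)[model]:
            if entry["extra_acc_spec"] == extra_acc_spec:
                return entry["accuracy"]
        raise KeyError(extra_acc_spec)

    print(lookup_reference("Qwen3/Qwen3-30B-A3B-Instruct-2507", "target_sparsity=0.9"))  # 45.9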
@@ -55,7 +55,7 @@ from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
                                  EagleDecodingConfig, KvCacheConfig, MoeConfig,
                                  MTPDecodingConfig, NGramDecodingConfig,
                                  RocketSparseAttentionConfig, SamplingParams,
-                                 TorchCompileConfig)
+                                 SkipSoftmaxAttentionConfig, TorchCompileConfig)
 from tensorrt_llm.quantization import QuantAlgo
 
 from ..conftest import (get_device_count, get_device_memory, llm_models_root,
@@ -64,7 +64,7 @@ from ..conftest import (get_device_count, get_device_memory, llm_models_root,
                         skip_pre_hopper, skip_ray)
 from .accuracy_core import (GSM8K, MMLU, CnnDailymail, GPQADiamond,
                             JsonModeEval, LlmapiAccuracyTestHarness,
-                            LongBenchV2)
+                            LongBenchV1, LongBenchV2)
 
 
 def _get_default_torch_compile_config(torch_compile):
@@ -3816,6 +3816,46 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
             task.evaluate(llm)
 
 
+class TestQwen3_30B_A3B_Instruct_2507(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "Qwen3/Qwen3-30B-A3B-Instruct-2507"
+    MODEL_PATH = f"{llm_models_root()}/{MODEL_NAME}"
+
+    @skip_pre_hopper
+    # @pytest.mark.skip_less_device_memory(140000)  # Only test for H200, B200
+    @pytest.mark.parametrize(
+        "target_sparsity,thr_prefill,thr_decode",
+        [
+            (0.0, 0.0, 0.0),
+            (0.5, 85.97384174442398, 55.48258322852407),
+            (0.9, 1418.142868970396, 863.147841750025),
+        ],
+        ids=[
+            "target_sparsity_0.0", "target_sparsity_0.5", "target_sparsity_0.9"
+        ],
+    )
+    def test_skip_softmax_attention(self, target_sparsity: float,
+                                    thr_prefill: float, thr_decode: float):
+        sparse_attention_config = SkipSoftmaxAttentionConfig(
+            threshold_scale_factor={
+                "prefill": thr_prefill,
+                "decode": thr_decode,
+            })
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.85)
+
+        if get_sm_version() >= 100:
+            pytest.skip("Bug to be fixed on Blackwell")
+
+        with LLM(self.MODEL_PATH,
+                 attn_backend="TRTLLM",
+                 max_batch_size=256,
+                 max_num_tokens=100000,
+                 kv_cache_config=kv_cache_config,
+                 sparse_attention_config=sparse_attention_config) as llm:
+            task = LongBenchV1(self.MODEL_NAME)
+            task.evaluate(llm,
+                          extra_acc_spec=f"target_sparsity={target_sparsity}")
+
+
 class TestPhi4MiniInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "microsoft/Phi-4-mini-instruct"
     MODEL_PATH = f"{llm_models_root()}/Phi-4-mini-instruct"
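The parametrized IDs of this test match the entries added to the test lists below. For a local run, assuming the usual pytest setup for the integration suite (working directory, model root, and other environment variables follow that setup and are not shown here), a single variant could be selected with something like:

    pytest "accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9]"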
@@ -55,6 +55,9 @@ l0_b200:
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-TRTLLM]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[mxfp8-latency-CUTLASS]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a16_mxfp4[latency-TRTLLM]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5]
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9]
   - accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp1-cutlass]
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0] # nvbugs 5300551
   - test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
@@ -76,6 +76,9 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_dummy_load_format
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.0] TIMEOUT (90)
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.5] TIMEOUT (90)
+  - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B_Instruct_2507::test_skip_softmax_attention[target_sparsity_0.9] TIMEOUT (90)
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=False]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=True-eagle3_one_model=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3[enable_chunked_prefill=False-eagle3_one_model=True]