commit ee6ab5be96 (parent df90202b51)
chore: add EXAONE4 accuracy test (#6397)
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
@@ -3,7 +3,6 @@ from typing import Optional, Tuple

 import torch
 from torch import nn

-from tensorrt_llm._torch.distributed import AllReduceParams
 from tensorrt_llm.functional import PositionEmbeddingType

 from ..attention_backend import AttentionMetadata
@@ -55,7 +54,6 @@ class Exaone4Attention(Attention):

     def __init__(self,
                  model_config: ModelConfig[Exaone4Config],
-                 is_sliding: bool,
                  layer_idx: Optional[int] = None,
                  aux_stream: Optional[torch.cuda.Stream] = None,
                  fuse_qk_norm_rope: bool = False):
@@ -64,9 +62,10 @@ class Exaone4Attention(Attention):
         self.attention_window_size = None

         # NOTE: In EXAONE4, only sliding layers apply rope.
-        self.is_sliding = is_sliding
+        self.sliding_window = config.sliding_window
+        self.is_sliding = check_is_sliding(config, layer_idx)
         pos_embd_params = None
-        if self.is_sliding:
+        if self.sliding_window is None or self.is_sliding:
             self.attention_window_size = config.sliding_window

             pos_embd_params = PositionalEmbeddingParams(
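For context: check_is_sliding decides per layer whether local (sliding-window) attention applies. Given EXAONE4's cyclic sliding_window_pattern string "LLLG" in the test config below, the check plausibly reduces to indexing that pattern; this is a minimal sketch under that assumption, not the actual TensorRT-LLM helper, which may differ.

# Hypothetical sketch of a per-layer sliding test, assuming the pattern
# string cycles over layers ("L" = local/sliding, "G" = global).
def is_sliding_layer(pattern: str, layer_idx: int) -> bool:
    return pattern[layer_idx % len(pattern)] == "L"

assert is_sliding_layer("LLLG", 0)      # layers 0-2 use the sliding window
assert not is_sliding_layer("LLLG", 3)  # every 4th layer attends globally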
@@ -140,7 +139,7 @@ class Exaone4Attention(Attention):

         q, k, v = self.split_qkv(q, k, v)
         q, k = self.apply_qk_norm(q, k)
-        if self.is_sliding:
+        if self.sliding_window is None or self.is_sliding:
             return super().apply_rope(q, k, v, position_ids)
         else:
             return q, k, v
@@ -152,7 +151,6 @@ class Exaone4Attention(Attention):
         attn_metadata: AttentionMetadata,
         attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
         CAUSAL,
-        all_reduce_params: Optional[AllReduceParams] = None,
         lora_params: Optional[dict] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -165,7 +163,6 @@ class Exaone4Attention(Attention):
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             attention_mask=attention_mask,
-            all_reduce_params=all_reduce_params,
             lora_params=lora_params,
             attention_window_size=self.attention_window_size,
             **kwargs,
@@ -185,11 +182,9 @@ class Exaone4DecoderLayer(DecoderLayer):
         self.layer_idx = layer_idx
         self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
         )
-        is_sliding = check_is_sliding(config, layer_idx)

         self.self_attn = Exaone4Attention(
             model_config,
-            is_sliding=is_sliding,
             layer_idx=layer_idx,
             aux_stream=aux_stream,
         )
@@ -228,8 +223,6 @@ class Exaone4DecoderLayer(DecoderLayer):
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
-            all_reduce_params=AllReduceParams(
-                enable_allreduce=not (self.mapping.tp_size == 1)),
             **kwargs,
         )

@@ -237,11 +230,7 @@ class Exaone4DecoderLayer(DecoderLayer):
         hidden_states = residual + hidden_states
         residual = hidden_states

-        hidden_states = self.mlp(
-            hidden_states,
-            final_all_reduce_params=AllReduceParams(
-                enable_allreduce=not (self.mapping.tp_size == 1)),
-        )
+        hidden_states = self.mlp(hidden_states)
         hidden_states = self.post_feedforward_layernorm(hidden_states)

         hidden_states = hidden_states + residual
@@ -153,3 +153,5 @@ microsoft/Phi-4-multimodal-instruct-long-rope:
   - accuracy: 75.85
 microsoft/Phi-4-mini-instruct:
   - accuracy: 82.30
+LGAI-EXAONE/EXAONE-4.0-32B:
+  - accuracy: 88.36
@@ -239,3 +239,5 @@ microsoft/Phi-4-multimodal-instruct:
   - accuracy: 69.69
 microsoft/Phi-4-multimodal-instruct-long-rope:
   - accuracy: 65.98
+LGAI-EXAONE/EXAONE-4.0-32B:
+  - accuracy: 78.52
@@ -2204,3 +2204,19 @@ class TestPhi4MM(LlmapiAccuracyTestHarness):
             task.evaluate(llm)
             task = GSM8K(model_name)
             task.evaluate(llm)
+
+
+class TestEXAONE4(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "LGAI-EXAONE/EXAONE-4.0-32B"
+    kv_cache_config = KvCacheConfig(
+        enable_block_reuse=False,
+        enable_partial_reuse=False,
+        max_attention_window=[4096, 4096, 4096, 131072])
+
+    def test_auto_dtype(self):
+        model_path = f"{llm_models_root()}/EXAONE-4.0-32B"
+        with LLM(model_path, kv_cache_config=self.kv_cache_config) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
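The per-layer window list above mirrors the model's hybrid attention. Assuming the runtime applies the list cyclically over the 64 layers (an assumption consistent with the "LLLG" pattern of three sliding layers per global layer used in the unit test config below), each group of four layers gets three 4096-token windows and one 131072-token window. A minimal sketch of that mapping:

# Sketch (assumption): a per-layer window list such as [4096, 4096, 4096,
# 131072] is applied cyclically across the model's layers.
def window_for_layer(max_attention_window: list, layer_idx: int) -> int:
    return max_attention_window[layer_idx % len(max_attention_window)]

assert window_for_layer([4096, 4096, 4096, 131072], 2) == 4096    # sliding
assert window_for_layer([4096, 4096, 4096, 131072], 3) == 131072  # global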
@@ -513,6 +513,7 @@ accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8
 accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype_long_rope
 accuracy/test_llm_api_pytorch.py::TestPhi4MiniInstruct::test_auto_dtype
+accuracy/test_llm_api_pytorch.py::TestEXAONE4::test_auto_dtype

 test_e2e.py::test_llama_e2e[use_cpp_session-remove_input_padding-]
 test_e2e.py::test_llama_e2e[use_py_session-remove_input_padding-]
@@ -18,6 +18,7 @@ l0_a30:
       - unittest/_torch/modeling -k "modeling_phi3"
       - unittest/_torch/modeling -k "modeling_qwen"
       - unittest/_torch/modeling -k "modeling_qwen_moe"
+      - unittest/_torch/modeling -k "modeling_exaone4"
       - unittest/_torch/auto_deploy/unit/singlegpu
       - unittest/_torch/test_beam_search.py
   - condition:
@@ -46,6 +46,7 @@ l0_h100:
       - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_eagle3
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
+      - accuracy/test_llm_api_pytorch.py::TestEXAONE4::test_auto_dtype
       - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
       - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
       - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu[DeepSeek-V3-Lite-fp8]
tests/unittest/_torch/modeling/test_modeling_exaone4.py (new file, 384 lines)
@@ -0,0 +1,384 @@
import unittest
from copy import deepcopy
from dataclasses import dataclass

import torch
from parameterized import parameterized

try:
    from transformers import Exaone4Config
except ImportError:
    # TODO: Remove this once we have a proper transformers package
    from transformers import PretrainedConfig

    class Exaone4Config(PretrainedConfig):
        model_type = "exaone4"


SKIP_EXAONE4_HF_ACCURACY_TEST = False
try:
    from transformers import Exaone4ForCausalLM as HFExaone4ForCausalLM
except ImportError:
    # TODO: Remove this once we have a proper config for Exaone4
    SKIP_EXAONE4_HF_ACCURACY_TEST = True

from transformers.cache_utils import HybridCache
from utils.util import getSMVersion

import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_exaone4 import Exaone4ForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
    DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig

WINDOW_SIZE = 4
EXAONE4_SINGLE_LAYER_CONFIG = {
    "architectures": ["Exaone4ForCausalLM"],
    "attention_dropout": 0.0,
    "bos_token_id": 1,
    "eos_token_id": 361,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 5120,
    "initializer_range": 0.02,
    "intermediate_size": 27392,
    "max_position_embeddings": 131072,
    "model_type": "exaone4",
    "num_attention_heads": 40,
    "num_hidden_layers": 4,  # NOTE: For testing, we use 4 instead of 64 (all layers)
    "num_key_value_heads": 8,
    "pad_token_id": 0,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 16.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3"
    },
    "rope_theta": 1000000,
    "sliding_window": 4,  # NOTE: For testing, we use 4 instead of 4096
    "sliding_window_pattern": "LLLG",
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.54.0.dev0",
    "use_cache": True,
    "vocab_size": 102400,
    "attn_implementation": "flash_attention_2"
}
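
# The config above shrinks EXAONE-4.0-32B for unit testing: 4 hidden layers
# instead of 64 and a 4-token sliding window instead of 4096, while keeping
# the "LLLG" hybrid-attention pattern (three local/sliding layers per global
# layer), so inputs of length WINDOW_SIZE - 1 or WINDOW_SIZE - 2 cross the
# window boundary cheaply.
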
@dataclass(repr=False)
class Scenario:
    backend: str
    input_len: int = WINDOW_SIZE - 1
    use_cuda_graph: bool = False

    def __repr__(self) -> str:
        return f"backend:{self.backend.lower()}-input_len:{self.input_len}-use_cuda_graph:{self.use_cuda_graph}"


class TestEXAONE4(unittest.TestCase):

    @parameterized.expand([None, "FP8"])
    def test_exaone4_sanity(self, quant_algo):
        config_dict = deepcopy(EXAONE4_SINGLE_LAYER_CONFIG)
        # TODO: Change to PretrainedConfig if we don't have the transformers version
        exaone4_config = Exaone4Config.from_dict(config_dict)
        if quant_algo:
            quant_config = QuantConfig(quant_algo=quant_algo)
        else:
            quant_config = None
        if quant_algo == "FP8" and getSMVersion() < 89:
            self.skipTest("This test is not supported in pre-Ada architecture")

        dtype = exaone4_config.torch_dtype
        device = torch.device('cuda')

        model_config = ModelConfig(pretrained_config=exaone4_config,
                                   quant_config=quant_config)
        exaone4 = Exaone4ForCausalLM(model_config).to(device)

        input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                 dtype=torch.int,
                                 device=device)

        context_sequence_lengths = [3, 2, 1]
        sequence_lengths = context_sequence_lengths + [1, 1]
        past_seen_tokens = [0, 0, 0, 62, 75]
        request_ids = list(range(len(sequence_lengths)))
        token_nums = (torch.tensor(past_seen_tokens) +
                      torch.tensor(sequence_lengths)).tolist()
        prompt_lens = token_nums[:3] + past_seen_tokens[3:]

        num_blocks = 100
        tokens_per_block = 128
        head_dim = exaone4.config.hidden_size // exaone4.config.num_attention_heads
        num_layers = exaone4.config.num_hidden_layers
        num_kv_heads = exaone4.config.num_key_value_heads
        max_seq_len = num_blocks * tokens_per_block
        batch_size = len(context_sequence_lengths) + 2

        if dtype == torch.half:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype")

        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)
        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            max_batch_size=batch_size,
            mapping=mapping,
            dtype=kv_cache_dtype,
        )
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
            num_contexts=len(context_sequence_lengths),
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=len(context_sequence_lengths) + 2,
            max_num_tokens=8192,
        )

        position_ids = []
        for i, tokens in enumerate(past_seen_tokens):
            seq_len = context_sequence_lengths[i] if i < len(
                context_sequence_lengths) else 1
            position_id = torch.arange(tokens,
                                       tokens + seq_len,
                                       device=input_ids.device)
            position_ids.append(position_id)

        position_ids = torch.cat(position_ids).unsqueeze(0)

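        # The batch mixes three context requests (seq lens 3, 2, 1) with two
        # generation requests that already have 62 and 75 cached tokens, so a
        # single forward pass exercises both the prefill and decode paths.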
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = exaone4.forward(input_ids=input_ids,
                                     position_ids=position_ids,
                                     attn_metadata=attn_metadata)

        self.assertEqual(len(past_seen_tokens), logits.shape[0])

        with torch.inference_mode():
            attn_metadata.prepare()
            logits = exaone4.forward(input_ids=input_ids,
                                     position_ids=position_ids,
                                     attn_metadata=attn_metadata,
                                     return_context_logits=True)
        self.assertEqual(input_ids.shape, logits.shape[:-1])

        kv_cache_manager.shutdown()

    @parameterized.expand([
        Scenario(backend="TRTLLM", input_len=WINDOW_SIZE - 2),
        Scenario(
            backend="TRTLLM", input_len=WINDOW_SIZE - 2, use_cuda_graph=True),
    ], lambda testcase_func, param_num, param:
                          f"{testcase_func.__name__}[{param.args[0]}]")
    @torch.no_grad()
    def test_exaone4_allclose_to_hf(self, scenario: Scenario) -> None:
        """
        Compare output to HF
        """
        # TODO: Remove this once we have a proper transformers version for Exaone4
        if SKIP_EXAONE4_HF_ACCURACY_TEST:
            self.skipTest("Exaone4 is not supported in this environment")

        backend = scenario.backend
        metadata_cls = get_attention_backend(backend).Metadata

        torch.random.manual_seed(0)
        config_dict = deepcopy(EXAONE4_SINGLE_LAYER_CONFIG)
        exaone4_config = Exaone4Config.from_dict(config_dict)
        dtype = exaone4_config.torch_dtype
        device = torch.device('cuda')

        # TODO: Or change to PreTrainedModel
        hf_exaone4 = HFExaone4ForCausalLM(exaone4_config).to(dtype).to(
            device).eval()

        model_config = ModelConfig(pretrained_config=exaone4_config,
                                   attn_backend=backend)
        exaone4 = Exaone4ForCausalLM(model_config).to(dtype).to(device)
        exaone4.load_weights(hf_exaone4.state_dict())

        num_blocks = 1
        tokens_per_block = 128
        head_dim = getattr(
            exaone4.config, "head_dim",
            exaone4.config.hidden_size // exaone4.config.num_attention_heads)
        num_layers = exaone4.config.num_hidden_layers
        num_kv_heads = exaone4.config.num_key_value_heads
        max_seq_len = num_blocks * tokens_per_block
        batch_size = 1
        hf_cache = HybridCache(config=exaone4_config,
                               max_batch_size=batch_size,
                               max_cache_len=max_seq_len,
                               device=device,
                               dtype=dtype)
        if dtype == torch.half:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype")

        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(
            enable_block_reuse=False,
            enable_partial_reuse=False,
            copy_on_partial_reuse=False,
            max_attention_window=[int(exaone4_config.sliding_window)],
            max_tokens=num_blocks * tokens_per_block)
        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            max_batch_size=batch_size,
            mapping=mapping,
            dtype=kv_cache_dtype,
        )

        # context
        input_ids = torch.tensor(
            [i * 100 for i in range(1, scenario.input_len + 1)],
            dtype=torch.int32,
            device=device)

        num_cached_tokens_per_seq = [0]
        request_ids = [1]
        token_nums = [input_ids.size(-1)]
        prompt_lens = [input_ids.size(-1)]
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        attn_metadata = metadata_cls(
            seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
            num_contexts=1,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            max_num_requests=1,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
        )

        # Note: no CUDA graphs for prefill, the graph runner is built for
        # decoding only.
        position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.int32)]
        position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = exaone4.forward(input_ids=input_ids,
                                     position_ids=position_ids,
                                     attn_metadata=attn_metadata)
            ref = hf_exaone4.forward(input_ids=input_ids.unsqueeze(0),
                                     position_ids=position_ids,
                                     past_key_values=hf_cache,
                                     use_cache=True)

        torch.testing.assert_close(logits,
                                   ref.logits[:, -1].float(),
                                   atol=0.4,
                                   rtol=0.4)

        # gen
        gen_input_ids = torch.tensor([600], dtype=torch.int32, device=device)
        num_cached_tokens_per_seq = [input_ids.size(-1)]

        attn_metadata = metadata_cls(
            seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
            num_contexts=0,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            max_num_requests=1,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
        )

        gen_position_ids = [
            torch.arange(input_ids.size(-1),
                         input_ids.size(-1) + gen_input_ids.size(-1),
                         dtype=torch.int32)
        ]
        gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

        def run_forward(input_ids, position_ids, attn_metadata):
            attn_metadata.prepare()
            if not scenario.use_cuda_graph:
                return exaone4.forward(input_ids=input_ids,
                                       position_ids=position_ids,
                                       attn_metadata=attn_metadata)
            else:
                graph_runner = DecodingCUDAGraphRunner(
                    attn_metadata.max_num_requests, "cuda", attn_metadata)
                graph_runner.capture(lambda inputs: exaone4.forward(**inputs))

                for _ in range(2):
                    # Run it twice. This helps us catch problems if buffers are
                    # accidentally reallocated in prepare().
                    attn_metadata.prepare()
                    logits = graph_runner.run({
                        "input_ids": input_ids,
                        "position_ids": position_ids,
                        "attn_metadata": attn_metadata,
                    })
                return logits

        if scenario.use_cuda_graph:
            attn_metadata = attn_metadata.create_cuda_graph_metadata(1)

        with torch.inference_mode():
            logits = run_forward(input_ids=gen_input_ids,
                                 position_ids=gen_position_ids,
                                 attn_metadata=attn_metadata)
            ref = hf_exaone4.forward(
                input_ids=gen_input_ids.unsqueeze(0),  # hf_gen_input_ids
                position_ids=gen_position_ids,
                past_key_values=ref.past_key_values,
                use_cache=True,
                cache_position=torch.LongTensor([input_ids.size(-1)
                                                 ]).to(device),
                last_cache_position=input_ids.size(-1) + 1)

        torch.testing.assert_close(logits,
                                   ref.logits[:, -1].float(),
                                   atol=0.4,
                                   rtol=0.4)

        kv_cache_manager.shutdown()
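For local runs, the new unit tests can be selected the same way the l0_a30 entry above does, i.e. by running the unittest/_torch/modeling suite with -k "modeling_exaone4" (assuming a CUDA-capable GPU, since both tests build the model on cuda; the HF-comparison test additionally skips itself when the installed transformers has no Exaone4 classes).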