mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[TRTLLM-5085][fix] Nemotron H correctness test (#4444)
* Replace the sanity test for Nemotron H with a correctness test
* Add prefill+decode reference logprobs from the initial implementation, plus a batched forward test
* Add a test that decode matches prefill: compare decode logprobs against prefilling the decoded tokens (sketched below)
This commit is contained in:
parent 21aff2e313
commit 7b09cd904d
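The decode-versus-prefill check added here compares per-token logprobs collected incrementally (prompt prefill followed by one-token decode steps) against the logprobs from a single prefill pass over the prompt plus its decoded completion. A minimal sketch of that idea, using hypothetical helper names (token_logprobs, assert_decode_matches_prefill) rather than the test's actual code:

import torch


def token_logprobs(token_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # logits[t] predicts token_ids[t + 1], so the first token has no logprob.
    probs = torch.softmax(logits[:-1].float(), dim=-1)
    return torch.log(probs.gather(1, token_ids[1:].unsqueeze(1)).squeeze(1))


def assert_decode_matches_prefill(incremental_logprobs: list[float],
                                  full_prefill_logprobs: list[float],
                                  atol: float = 0.3) -> None:
    # incremental_logprobs: prompt prefill + per-step decode logprobs
    # full_prefill_logprobs: one prefill pass over prompt + decoded completion
    torch.testing.assert_close(torch.tensor(incremental_logprobs),
                               torch.tensor(full_prefill_logprobs),
                               atol=atol,
                               rtol=0.0)

The test below performs this comparison with its own generate/get_logprobs helpers, and additionally pins batched and unbatched generation to reference logprobs from mcore and from the initial implementation.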
@@ -1,5 +1,4 @@
import unittest
from copy import deepcopy

import torch

@@ -13,80 +12,190 @@ from tensorrt_llm._torch.models.modeling_nemotron_h import (NemotronHConfig,
                                                            NemotronHForCausalLM
                                                            )
# isort: on
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from utils.llm_data import llm_models_root
from utils.util import skip_gpu_memory_less_than

from tensorrt_llm._torch.pyexecutor.model_engine import load_weights
from tensorrt_llm._torch.pyexecutor.resource_manager import \
    MambaHybridCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping

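# Nemotron-H 8B configuration (hybrid Mamba-2 / attention / MLP layer pattern).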
NEMOTRON_H_CONFIG = {
    "architectures": ["NemotronHForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "attention_head_dim": 128,
    "bos_token_id": 1,
    "chunk_size": 256,
    "conv_kernel": 4,
    "eos_token_id": 2,
    "expand": 2,
    "hidden_dropout": 0.0,
    "hidden_size": 4096,
    "hybrid_override_pattern":
    "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
    "initializer_range": 0.02,
    "intermediate_size": 21504,
    "layer_norm_epsilon": 1e-05,
    "mamba_head_dim": 64,
    "mamba_hidden_act": "silu",
    "mamba_num_heads": 128,
    "mamba_proj_bias": False,
    "max_position_embeddings": 8192,
    "mlp_bias": False,
    "mlp_hidden_act": "relu2",
    "model_type": "nemotron_h",
    "n_groups": 8,
    "num_attention_heads": 32,
    "num_hidden_layers": 52,
    "num_key_value_heads": 8,
    "num_logits_to_keep": 1,
    "pad_token_id": 0,
    "rescale_prenorm_residual": True,
    "residual_in_fp32": False,
    "rms_norm_eps": 1e-05,
    "sliding_window": None,
    "ssm_state_size": 128,
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "use_bias": False,
    "use_cache": True,
    "use_conv_bias": True,
    "use_mamba_kernels": True,
    "vocab_size": 131072
}

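# Softmax the logits, gather the probability of each target token, and return log-probabilities.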
def get_logprobs(token_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    raw_probs = torch.softmax(logits, dim=-1)
    index = token_ids.unsqueeze(1).cuda()
    token_probs = torch.gather(raw_probs, dim=1, index=index).squeeze(-1)
    return torch.log(token_probs)


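# Prefill the prompts in a single batched forward pass, compute per-token logprobs,
# then greedily decode `tokens_to_generate` tokens one step at a time, reusing the
# hybrid Mamba/KV cache between steps.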
def _generate(
    model: NemotronHForCausalLM, tokenizer: PreTrainedTokenizerBase,
    cache: MambaHybridCacheManager, text_prompts: list[str],
    tokens_to_generate: int, device: torch.device
) -> tuple[list[int], list[list[int]], list[list[float]]]:
    num_seqs = len(text_prompts)
    all_token_ids = [
        tokenizer.encode(prompt, add_special_tokens=False)
        for prompt in text_prompts
    ]
    input_ids = torch.cat([
        torch.tensor(token_ids, dtype=torch.int64, device=device)
        for token_ids in all_token_ids
    ],
                          dim=0)
    request_ids = list(range(1, num_seqs + 1))
    prompt_lens = [len(token_ids) for token_ids in all_token_ids]

    requests = cache.add_dummy_requests(request_ids, prompt_lens)
    cache.prepare_mamba_cache_blocks(request_ids)

    metadata_cls = get_attention_backend(
        model.model_config.attn_backend).Metadata
    attn_metadata = metadata_cls(
        seq_lens=torch.tensor(prompt_lens, dtype=torch.int),
        num_contexts=num_seqs,
        kv_cache_params=KVCacheParams(
            use_cache=True,
            num_cached_tokens_per_seq=[0] * num_seqs,
        ),
        max_num_requests=num_seqs,
        max_num_tokens=8192,
        kv_cache_manager=cache,
        request_ids=request_ids,
        prompt_lens=prompt_lens,
    )

    # prefill
    position_ids = [torch.arange(0, prompt_len) for prompt_len in prompt_lens]
    position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
    with torch.inference_mode():
        attn_metadata.prepare()
        logits = model.forward(input_ids=input_ids,
                               position_ids=position_ids,
                               attn_metadata=attn_metadata,
                               return_context_logits=True)

    # compute logprobs from logits
    all_logits = logits.split(prompt_lens, dim=0)
    all_logprobs = [
        get_logprobs(
            torch.tensor(token_ids[1:], dtype=torch.int64, device=device),
            this_logits[:-1]).tolist()
        for token_ids, this_logits in zip(all_token_ids, all_logits)
    ]

    if tokens_to_generate > 0:
        # sample token greedily
        sampled_tokens = torch.cat([
            torch.argmax(this_logits[-1]).unsqueeze(0)
            for this_logits in all_logits
        ],
                                   dim=0)
        for i in range(num_seqs):
            all_token_ids[i].append(sampled_tokens[i].item())
            all_logprobs[i].append(
                get_logprobs(sampled_tokens[i].unsqueeze(0),
                             all_logits[i][-1:]).item())

        # one token already generated at prefill
        for i in range(tokens_to_generate - 1):
            num_cached_tokens_per_seq = [
                prompt_len + i + 1 for prompt_len in prompt_lens
            ]
            position_ids = torch.tensor([num_cached_tokens_per_seq],
                                        dtype=torch.int64,
                                        device=device)

            attn_metadata = metadata_cls(
                seq_lens=torch.tensor([1] * num_seqs, dtype=torch.int),
                num_contexts=0,
                kv_cache_params=KVCacheParams(
                    use_cache=True,
                    num_cached_tokens_per_seq=num_cached_tokens_per_seq,
                ),
                max_num_requests=num_seqs,
                max_num_tokens=8192,
                kv_cache_manager=cache,
                request_ids=request_ids,
                prompt_lens=prompt_lens,
            )

            with torch.inference_mode():
                attn_metadata.prepare()
                logits = model.forward(input_ids=sampled_tokens,
                                       position_ids=position_ids,
                                       attn_metadata=attn_metadata)

            # sample token greedily
            sampled_tokens = torch.argmax(logits, dim=-1, keepdim=False)
            for i in range(num_seqs):
                all_token_ids[i].append(sampled_tokens[i].item())
                all_logprobs[i].append(
                    get_logprobs(sampled_tokens[i].unsqueeze(0),
                                 logits[i].unsqueeze(0)).item())

    for req in requests:
        cache.free_resources(req)

    return prompt_lens, all_token_ids, all_logprobs


def generate(
    model: NemotronHForCausalLM,
    tokenizer: PreTrainedTokenizerBase,
    cache: MambaHybridCacheManager,
    text_prompts: list[str],
    tokens_to_generate: int,
    device: torch.device,
    one_by_one: bool = False
) -> tuple[list[int], list[list[int]], list[list[float]]]:
    """
    Generate `tokens_to_generate` tokens from the given prompts using the given model and cache.
    Return the prompt_lens along with the prefill+generated tokens and their logprobs, minus the first token in the prompt.
    """
    if one_by_one:
        num_prompts = len(text_prompts)
        prompt_lens, tokens, logprobs = [None] * num_prompts, [
            None
        ] * num_prompts, [None] * num_prompts
        for i in range(num_prompts):
            p, t, l = _generate(model, tokenizer, cache, [text_prompts[i]],
                                tokens_to_generate, device)
            prompt_lens[i], tokens[i], logprobs[i] = p[0], t[0], l[0]
        return prompt_lens, tokens, logprobs
    return _generate(model, tokenizer, cache, text_prompts, tokens_to_generate,
                     device)


class TestNemotronH(unittest.TestCase):

    def test_nemotron_h_sanity(self):
        config_dict = deepcopy(NEMOTRON_H_CONFIG)
        nemotron_h_config = NemotronHConfig.from_dict(config_dict)
    @skip_gpu_memory_less_than(
        (2 * 8 + 1) * 2**30)  # 8B, bf16, plus 1 GB for good measure
    def test_nemotron_correctness(self):
        model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K"
        nemotron_h_config = NemotronHConfig.from_pretrained(model_dir)

        tokenizer = AutoTokenizer.from_pretrained(model_dir)

        dtype = nemotron_h_config.torch_dtype
        device = torch.device('cuda')
        assert dtype == torch.bfloat16
        kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16

        model_config = ModelConfig(pretrained_config=nemotron_h_config)
        nemotron_h = NemotronHForCausalLM(model_config).to(device)

        input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                 dtype=torch.int,
                                 device=device)
        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        weights = load_weights(model_dir, mapping=mapping)
        nemotron_h.load_weights(weights)

        context_sequence_lengths = [3, 2, 1]
        sequence_lengths = context_sequence_lengths + [1, 1]
        past_seen_tokens = [0, 0, 0, 62, 75]
        request_ids = list(range(len(sequence_lengths)))
        token_nums = (torch.tensor(past_seen_tokens) +
                      torch.tensor(sequence_lengths)).tolist()
        prompt_lens = token_nums[:3] + past_seen_tokens[3:]
        text_prompts = [
            "The future of AI is",
            "The president of the United States is",
        ]
        num_prompts = len(text_prompts)

        num_blocks = 100
        tokens_per_block = 128
@@ -101,14 +210,8 @@ class TestNemotronH(unittest.TestCase):
        ]
        num_kv_heads = nemotron_h.config.num_key_value_heads
        max_seq_len = num_blocks * tokens_per_block
        batch_size = len(context_sequence_lengths) + 2
        max_batch_size = num_prompts

        if dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype")

        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block,
                                        enable_block_reuse=False)
@@ -132,53 +235,181 @@ class TestNemotronH(unittest.TestCase):
            head_dim=head_dim,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            max_batch_size=batch_size,
            max_batch_size=max_batch_size,
            mapping=mapping,
            dtype=kv_cache_dtype,
        )
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)
        kv_cache_manager.prepare_mamba_cache_blocks(request_ids)

        metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
            num_contexts=len(context_sequence_lengths),
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=len(context_sequence_lengths) + 2,
            max_num_tokens=8192,
        )
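        # Generate from each prompt separately (one_by_one=True) to get an unbatched reference.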
        prompt_lens, tokens_no_batching, logprobs_no_batching = generate(
            model=nemotron_h,
            tokenizer=tokenizer,
            cache=kv_cache_manager,
            text_prompts=text_prompts,
            tokens_to_generate=9,
            device=torch.device("cuda"),
            one_by_one=True)
        completions_no_batching = [
            tokenizer.decode(tokens_no_batching[i][prompt_lens[i]:])
            for i in range(num_prompts)
        ]

        position_ids = []
        for i, tokens in enumerate(past_seen_tokens):
            seq_len = context_sequence_lengths[i] if i < len(
                context_sequence_lengths) else 1
            position_id = torch.arange(tokens,
                                       tokens + seq_len,
                                       device=input_ids.device)
            position_ids.append(position_id)
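        # Repeat generation with both prompts batched together in a single call.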
        _, tokens_batching, logprobs_batching = generate(
            model=nemotron_h,
            tokenizer=tokenizer,
            cache=kv_cache_manager,
            text_prompts=text_prompts,
            tokens_to_generate=9,
            device=torch.device("cuda"))
        completions_batching = [
            tokenizer.decode(tokens_batching[i][prompt_lens[i]:])
            for i in range(num_prompts)
        ]

        position_ids = torch.cat(position_ids).unsqueeze(0)
        # reference logprobs for first prompt from mcore for prompt minus first token
        # TODO(oargov): generate a reference on-the-fly once we have confidence in the HF impl
        prefill_logprobs_ref_mcore = torch.tensor([
            -7.415980815887451, -0.36192911863327026, -2.8658294677734375,
            -2.316344738006592
        ])

        with torch.inference_mode():
            attn_metadata.prepare()
            logits = nemotron_h.forward(input_ids=input_ids,
                                        position_ids=position_ids,
                                        attn_metadata=attn_metadata)
        # compare logprobs with mcore logprobs, check that the max error is less than 0.3
        mcore_atol = 0.3
        torch.testing.assert_close(torch.tensor(
            logprobs_no_batching[0][:prompt_lens[0] - 1]),
                                   prefill_logprobs_ref_mcore,
                                   atol=mcore_atol,
                                   rtol=0.0)

        self.assertEqual(len(past_seen_tokens), logits.shape[0])
        # reference logprobs from initial implementation (commit 5ce1102a02bd2938c0c8334138371f081f55fcc1 on single RTX 6000)
        initial_impl_atol = 0.2
        batching_atol = 0.2

        with torch.inference_mode():
            attn_metadata.prepare()
            logits = nemotron_h.forward(input_ids=input_ids,
                                        position_ids=position_ids,
                                        attn_metadata=attn_metadata,
                                        return_context_logits=True)
        self.assertEqual(input_ids.shape, logits.shape[:-1])
        prefill_logprobs_ref_initial_no_batching = [
            torch.tensor([
                -7.4359540939331055,
                -0.37661877274513245,
                -2.8925108909606934,
                -2.268364906311035,
            ]),
            torch.tensor([
                -8.759482383728027,
                -1.656238079071045,
                -0.5448741912841797,
                -1.7702054977416992,
                -0.05832016468048096,
                -1.460732102394104,
            ])
        ]
        prefill_logprobs_ref_initial_with_batching = [
            torch.tensor([
                -7.401950836181641, -0.38696032762527466, -2.8725428581237793,
                -2.2654521465301514
            ]),
            torch.tensor([
                -8.73007583618164, -1.6853574514389038, -0.5468529462814331,
                -1.7846013307571411, -0.053610533475875854, -1.4385275840759277
            ])
        ]

        decode_logprobs_ref_initial_no_batching = [
            torch.tensor([
                -2.2722280025482178, -0.5235245823860168, -0.8821321725845337,
                -1.9436249732971191, -0.07366813719272614, -0.4224405586719513,
                -0.3872227966785431, -0.06121065467596054, -1.0475994348526
            ]),
            torch.tensor([
                -1.329713225364685, -1.6879069805145264, -0.040034178644418716,
                -0.4808207154273987, -0.3581068515777588, -0.2784178853034973,
                -0.005814795847982168, -0.0563097707927227, -0.05941024422645569
            ])
        ]
        decode_logprobs_ref_initial_with_batching = [
            torch.tensor([
                -2.2877156734466553, -0.507795512676239, -0.8313305377960205,
                -1.940523386001587, -0.07369701564311981, -0.4190545976161957,
                -0.4250463843345642, -0.061063338071107864, -1.046282410621643
            ]),
            torch.tensor([
                -1.3567769527435303, -1.7291667461395264, -0.04527968540787697,
                -0.4836069345474243, -0.3971801698207855, -0.2481495887041092,
                -0.005787517875432968, -0.056093256920576096,
                -0.058267030864953995
            ])
        ]

        expected_completions = [
            " bright, with endless possibilities for innovation and growth",
            " the head of state and head of government of",
        ]

        for i in range(num_prompts):
            prefill_logprobs_no_batching = torch.tensor(
                logprobs_no_batching[i][:prompt_lens[i] - 1])
            decode_logprobs_no_batching = torch.tensor(
                logprobs_no_batching[i][prompt_lens[i] - 1:])

            prefill_logprobs_batching = torch.tensor(
                logprobs_batching[i][:prompt_lens[i] - 1])
            decode_logprobs_batching = torch.tensor(
                logprobs_batching[i][prompt_lens[i] - 1:])

            # compare prompt logprobs with initial implementation
            torch.testing.assert_close(
                prefill_logprobs_no_batching,
                prefill_logprobs_ref_initial_no_batching[i],
                atol=initial_impl_atol,
                rtol=0.0)
            torch.testing.assert_close(
                prefill_logprobs_batching,
                prefill_logprobs_ref_initial_with_batching[i],
                atol=initial_impl_atol,
                rtol=0.0)

            # compare expected completion
            self.assertEqual(completions_batching[i], expected_completions[i])
            self.assertEqual(completions_no_batching[i],
                             expected_completions[i])

            # compare decode logprobs with initial implementation
            torch.testing.assert_close(
                decode_logprobs_no_batching,
                decode_logprobs_ref_initial_no_batching[i],
                atol=initial_impl_atol,
                rtol=0.0)
            torch.testing.assert_close(
                decode_logprobs_batching,
                decode_logprobs_ref_initial_with_batching[i],
                atol=initial_impl_atol,
                rtol=0.0)

            # compare logprobs with and without batching, tolerance set by the diff in the initial implementation
            torch.testing.assert_close(prefill_logprobs_batching,
                                       prefill_logprobs_no_batching,
                                       atol=batching_atol,
                                       rtol=0.0)
            torch.testing.assert_close(decode_logprobs_batching,
                                       decode_logprobs_no_batching,
                                       atol=batching_atol,
                                       rtol=0.0)

        # now let's test that decodes match prefill logprobs
        text_prompts_with_completions = [
            f"{text_prompts[i]}{completions_batching[i]}"
            for i in range(num_prompts)
        ]
        _, _, full_sequence_logprobs = generate(
            model=nemotron_h,
            tokenizer=tokenizer,
            cache=kv_cache_manager,
            text_prompts=text_prompts_with_completions,
            tokens_to_generate=0,
            device=torch.device("cuda"))

        # compare full sequence logprobs with prefill+decode logprobs, tolerance like mcore tolerance
        for i in range(num_prompts):
            torch.testing.assert_close(torch.tensor(full_sequence_logprobs[i]),
                                       torch.tensor(logprobs_batching[i]),
                                       atol=mcore_atol,
                                       rtol=0.0)

        kv_cache_manager.shutdown()