[test] Use LLM API for Nemotron-H correctness test (#5097)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
tomeras91 2025-06-12 09:54:46 +03:00 committed by GitHub
parent 505678a286
commit 06d9f1e2f6


@@ -1,26 +1,10 @@
import torch
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
# isort: off
from tensorrt_llm._torch.models.modeling_nemotron_h import (NemotronHConfig,
NemotronHForCausalLM
)
# isort: on
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from utils.llm_data import llm_models_root
from utils.util import skip_gpu_memory_less_than
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.model_engine import load_weights
from tensorrt_llm._torch.pyexecutor.resource_manager import \
MambaHybridCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig as KvCacheConfigCpp
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.sampling_params import SamplingParams
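# A minimal sketch of the `get_logprobs` helper named in the hunk header below;
# only its final line survives in the hunk, so this body is an assumption based
# on that return statement and on how the test calls it:
import torch

def get_logprobs(token_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # Probability of each realized token under the corresponding row of logits.
    token_probs = torch.softmax(logits.float(), dim=-1)[
        torch.arange(token_ids.shape[0], device=logits.device), token_ids]
    return torch.log(token_probs)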
@@ -31,236 +15,43 @@ def get_logprobs(token_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
return torch.log(token_probs)
def _generate(
model: NemotronHForCausalLM, tokenizer: PreTrainedTokenizerBase,
cache: MambaHybridCacheManager, text_prompts: list[str],
tokens_to_generate: int, device: torch.device
) -> tuple[list[int], list[list[int]], list[list[float]]]:
num_seqs = len(text_prompts)
all_token_ids = [
tokenizer.encode(prompt, add_special_tokens=False)
for prompt in text_prompts
]
input_ids = torch.cat([
torch.tensor(token_ids, dtype=torch.int64, device=device)
for token_ids in all_token_ids
],
dim=0)
request_ids = list(range(1, num_seqs + 1))
prompt_lens = [len(token_ids) for token_ids in all_token_ids]
requests = cache.add_dummy_requests(request_ids, prompt_lens)
cache.prepare_mamba_cache_blocks(request_ids)
metadata_cls = get_attention_backend(
model.model_config.attn_backend).Metadata
attn_metadata = metadata_cls(
seq_lens=torch.tensor(prompt_lens, dtype=torch.int),
num_contexts=num_seqs,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=[0] * num_seqs,
),
max_num_requests=num_seqs,
max_num_tokens=8192,
kv_cache_manager=cache,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
# prefill
position_ids = [torch.arange(0, prompt_len) for prompt_len in prompt_lens]
position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
with torch.inference_mode():
attn_metadata.prepare()
logits = model.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
return_context_logits=True)
# compute logprobs from logits
all_logits = logits.split(prompt_lens, dim=0)
all_logprobs = [
get_logprobs(
torch.tensor(token_ids[1:], dtype=torch.int64, device=device),
this_logits[:-1]).tolist()
for token_ids, this_logits in zip(all_token_ids, all_logits)
]
if tokens_to_generate > 0:
# sample token greedily
sampled_tokens = torch.cat([
torch.argmax(this_logits[-1]).unsqueeze(0)
for this_logits in all_logits
],
dim=0)
for i in range(num_seqs):
all_token_ids[i].append(sampled_tokens[i].item())
all_logprobs[i].append(
get_logprobs(sampled_tokens[i].unsqueeze(0),
all_logits[i][-1:]).item())
# one token already generated at prefill
for i in range(tokens_to_generate - 1):
num_cached_tokens_per_seq = [
prompt_len + i + 1 for prompt_len in prompt_lens
]
position_ids = torch.tensor([num_cached_tokens_per_seq],
dtype=torch.int64,
device=device)
attn_metadata = metadata_cls(
seq_lens=torch.tensor([1] * num_seqs, dtype=torch.int),
num_contexts=0,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
max_num_requests=num_seqs,
max_num_tokens=8192,
kv_cache_manager=cache,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
with torch.inference_mode():
attn_metadata.prepare()
logits = model.forward(input_ids=sampled_tokens,
position_ids=position_ids,
attn_metadata=attn_metadata)
# sample token greedily
sampled_tokens = torch.argmax(logits, dim=-1, keepdim=False)
for i in range(num_seqs):
all_token_ids[i].append(sampled_tokens[i].item())
all_logprobs[i].append(
get_logprobs(sampled_tokens[i].unsqueeze(0),
logits[i].unsqueeze(0)).item())
for req in requests:
cache.free_resources(req)
return prompt_lens, all_token_ids, all_logprobs
def extract_prefill_logprobs(result: RequestOutput) -> torch.Tensor:
token_ids = torch.tensor(result.prompt_token_ids[1:])
logits = result.context_logits[:-1, :]
return get_logprobs(token_ids, logits)
def generate(
model: NemotronHForCausalLM,
tokenizer: PreTrainedTokenizerBase,
cache: MambaHybridCacheManager,
text_prompts: list[str],
tokens_to_generate: int,
device: torch.device,
one_by_one: bool = False
) -> tuple[list[int], list[list[int]], list[list[float]]]:
"""
Generate `tokens_to_generate` tokens from the given prompts using the given model and cache.
Return the prompt_lens along with the prefill+generated tokens and their logprobs, minus the first token in the prompt.
"""
if one_by_one:
num_prompts = len(text_prompts)
prompt_lens, tokens, logprobs = [None] * num_prompts, [
None
] * num_prompts, [None] * num_prompts
for i in range(num_prompts):
p, t, l = _generate(model, tokenizer, cache, [text_prompts[i]],
tokens_to_generate, device)
prompt_lens[i], tokens[i], logprobs[i] = p[0], t[0], l[0]
return prompt_lens, tokens, logprobs
return _generate(model, tokenizer, cache, text_prompts, tokens_to_generate,
device)
def extract_decode_logprobs(result: RequestOutput,
gen_idx: int = 0) -> torch.Tensor:
token_ids = torch.tensor(result.outputs[gen_idx].token_ids)
logits = result.outputs[gen_idx].generation_logits
return get_logprobs(token_ids, logits)
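# Illustration of the off-by-one alignment used by extract_prefill_logprobs:
# context-logits row t predicts prompt token t + 1, so the first prompt token
# has no predicting row and is dropped (hypothetical shapes, not test data):
import torch

prompt_token_ids = [10, 11, 12, 13]           # four hypothetical prompt tokens
context_logits = torch.randn(4, 32)           # one row per prompt token, vocab size 32
targets = torch.tensor(prompt_token_ids[1:])  # tokens that have a predicting row
predicting_rows = context_logits[:-1, :]      # rows aligned with those tokens
assert targets.shape[0] == predicting_rows.shape[0] == 3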
@skip_gpu_memory_less_than(
(2 * 8 + 1) * 2**30) # 8B, bf16, plus 1 GB for good measure
def test_nemotron_h_correctness():
# This test is close to the memory limit on an A30 (24GB), so empty the cache first
torch.cuda.empty_cache()
model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K"
nemotron_h_config = NemotronHConfig.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
dtype = nemotron_h_config.torch_dtype
device = torch.device('cuda')
assert dtype == torch.bfloat16
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
model_config = ModelConfig(pretrained_config=nemotron_h_config)
nemotron_h = NemotronHForCausalLM(model_config).to(device)
mapping = Mapping(world_size=1, tp_size=1, rank=0)
weights = load_weights(model_dir)
nemotron_h.load_weights(weights)
text_prompts = [
"The future of AI is",
"The president of the United States is",
]
num_prompts = len(text_prompts)
num_blocks = 100
tokens_per_block = 128
head_dim = nemotron_h.config.hidden_size // nemotron_h.config.num_attention_heads
num_layers = nemotron_h.config.hybrid_override_pattern.count("*")
layer_mask = [
char == "*" for char in nemotron_h.config.hybrid_override_pattern
]
mamba_num_layers = nemotron_h.config.hybrid_override_pattern.count("M")
mamba_layer_mask = [
char == "M" for char in nemotron_h.config.hybrid_override_pattern
]
num_kv_heads = nemotron_h.config.num_key_value_heads
max_seq_len = num_blocks * tokens_per_block
max_batch_size = num_prompts
kv_cache_config = KvCacheConfigCpp(max_tokens=num_blocks * tokens_per_block,
enable_block_reuse=False)
kv_cache_manager = MambaHybridCacheManager(
# mamba cache parameters
nemotron_h.config.hidden_size,
nemotron_h.config.ssm_state_size,
nemotron_h.config.conv_kernel,
nemotron_h.config.expand,
nemotron_h.config.n_groups,
nemotron_h.config.mamba_head_dim,
mamba_num_layers,
mamba_layer_mask,
nemotron_h.config.torch_dtype,
# kv cache parameters
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
nemotron_h = LLM(
model=model_dir,
max_batch_size=num_prompts,
use_cuda_graph=False,
kv_cache_config=KvCacheConfig(enable_block_reuse=False),
enable_trtllm_sampler=True,
)
prompt_lens, tokens_no_batching, logprobs_no_batching = generate(
model=nemotron_h,
tokenizer=tokenizer,
cache=kv_cache_manager,
text_prompts=text_prompts,
tokens_to_generate=9,
device=torch.device("cuda"),
one_by_one=True)
completions_no_batching = [
tokenizer.decode(tokens_no_batching[i][prompt_lens[i]:])
for i in range(num_prompts)
]
_, tokens_batching, logprobs_batching = generate(
model=nemotron_h,
tokenizer=tokenizer,
cache=kv_cache_manager,
text_prompts=text_prompts,
tokens_to_generate=9,
device=torch.device("cuda"))
completions_batching = [
tokenizer.decode(tokens_batching[i][prompt_lens[i]:])
for i in range(num_prompts)
expected_completions = [
" bright, with endless possibilities for innovation and growth",
" the head of state and head of government of",
]
# reference logprobs from mcore for the first prompt, excluding the first prompt token
@@ -270,14 +61,6 @@ def test_nemotron_h_correctness():
-2.316344738006592
])
# compare logprobs with mcore logprobs; check that the max error is less than 0.3
mcore_atol = 0.3
torch.testing.assert_close(torch.tensor(
logprobs_no_batching[0][:prompt_lens[0] - 1]),
prefill_logprobs_ref_mcore,
atol=mcore_atol,
rtol=0.0)
# reference logprobs from the initial implementation (commit 5ce1102a02bd2938c0c8334138371f081f55fcc1, on a single RTX 6000)
initial_impl_atol = 0.2
batching_atol = 0.2
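# The checks below always pass rtol=0.0, so torch.testing.assert_close reduces
# to a pure max-absolute-error test against atol; a standalone illustration
# with made-up values:
import torch

ref = torch.tensor([-2.27, -0.52, -0.88])
obs = ref + torch.tensor([0.05, -0.10, 0.02])             # pretend observed logprobs
torch.testing.assert_close(obs, ref, atol=0.3, rtol=0.0)  # max |obs - ref| = 0.10 <= 0.3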
@@ -311,137 +94,133 @@ def test_nemotron_h_correctness():
decode_logprobs_ref_initial_no_batching = [
torch.tensor([
-2.2722280025482178, -0.5235245823860168, -0.8821321725845337,
-1.9436249732971191, -0.07366813719272614, -0.4224405586719513,
-0.3872227966785431, -0.06121065467596054, -1.0475994348526
-2.2722280025482178, -0.5124826431274414, -0.7916123270988464,
-2.1908130645751953, -0.059298671782016754, -0.5125972032546997,
-0.3856367766857147, -0.055953752249479294, -1.1059765815734863
]),
torch.tensor([
-1.329713225364685, -1.6879069805145264, -0.040034178644418716,
-0.4808207154273987, -0.3581068515777588, -0.2784178853034973,
-0.005814795847982168, -0.0563097707927227, -0.05941024422645569
-1.329713225364685, -1.5038213729858398, -0.021283088251948357,
-0.38457369804382324, -0.3582419157028198, -0.16527847945690155,
-0.0044861179776489735, -0.059462934732437134, -0.041099339723587036
])
]
decode_logprobs_ref_initial_with_batching = [
torch.tensor([
-2.2877156734466553, -0.507795512676239, -0.8313305377960205,
-1.940523386001587, -0.07369701564311981, -0.4190545976161957,
-0.4250463843345642, -0.061063338071107864, -1.046282410621643
-2.2877156734466553, -0.46699056029319763, -0.7909849286079407,
-2.1276988983154297, -0.062114741653203964, -0.5291495323181152,
-0.38685765862464905, -0.05595658719539642, -1.1020748615264893
]),
torch.tensor([
-1.3567769527435303, -1.7291667461395264, -0.04527968540787697,
-0.4836069345474243, -0.3971801698207855, -0.2481495887041092,
-0.005787517875432968, -0.056093256920576096, -0.058267030864953995
-1.3567769527435303, -1.5647790431976318, -0.022344056516885757,
-0.38503751158714294, -0.3581986725330353, -0.18398350477218628,
-0.004726295825093985, -0.05941498652100563, -0.04291720315814018
])
]
expected_completions = [
" bright, with endless possibilities for innovation and growth",
" the head of state and head of government of",
]
try:
sampling_params = SamplingParams(max_tokens=9,
temperature=0.0,
add_special_tokens=False,
return_context_logits=True,
return_generation_logits=True)
for i in range(num_prompts):
prefill_logprobs_no_batching = torch.tensor(
logprobs_no_batching[i][:prompt_lens[i] - 1])
decode_logprobs_no_batching = torch.tensor(
logprobs_no_batching[i][prompt_lens[i] - 1:])
results_no_batching = [
nemotron_h.generate(text_prompt, sampling_params)
for text_prompt in text_prompts
]
completions_no_batching = [
result.outputs[0].text for result in results_no_batching
]
prefill_logprobs_no_batching = [
extract_prefill_logprobs(result).cpu()
for result in results_no_batching
]
decode_logprobs_no_batching = [
extract_decode_logprobs(result).cpu()
for result in results_no_batching
]
prefill_logprobs_batching = torch.tensor(
logprobs_batching[i][:prompt_lens[i] - 1])
decode_logprobs_batching = torch.tensor(
logprobs_batching[i][prompt_lens[i] - 1:])
results_batching = nemotron_h.generate(text_prompts, sampling_params)
completions_batching = [
result.outputs[0].text for result in results_batching
]
prefill_logprobs_batching = [
extract_prefill_logprobs(result).cpu()
for result in results_batching
]
decode_logprobs_batching = [
extract_decode_logprobs(result).cpu() for result in results_batching
]
# compare prompt logprobs with initial implementation
torch.testing.assert_close(prefill_logprobs_no_batching,
prefill_logprobs_ref_initial_no_batching[i],
atol=initial_impl_atol,
rtol=0.0)
torch.testing.assert_close(
prefill_logprobs_batching,
prefill_logprobs_ref_initial_with_batching[i],
atol=initial_impl_atol,
rtol=0.0)
# compare expected completion
assert completions_batching[i] == expected_completions[i]
assert completions_no_batching[i] == expected_completions[i]
# compare decode logprobs with initial implementation
torch.testing.assert_close(decode_logprobs_no_batching,
decode_logprobs_ref_initial_no_batching[i],
atol=initial_impl_atol,
rtol=0.0)
torch.testing.assert_close(decode_logprobs_batching,
decode_logprobs_ref_initial_with_batching[i],
atol=initial_impl_atol,
rtol=0.0)
# compare logprobs with and without batching; tolerance based on the difference observed in the initial implementation
torch.testing.assert_close(prefill_logprobs_batching,
prefill_logprobs_no_batching,
atol=batching_atol,
rtol=0.0)
torch.testing.assert_close(decode_logprobs_batching,
decode_logprobs_no_batching,
atol=batching_atol,
rtol=0.0)
# now let's test that decodes match prefill logprobs
text_prompts_with_completions = [
f"{text_prompts[i]}{completions_batching[i]}"
for i in range(num_prompts)
]
_, _, full_sequence_logprobs = generate(
model=nemotron_h,
tokenizer=tokenizer,
cache=kv_cache_manager,
text_prompts=text_prompts_with_completions,
tokens_to_generate=0,
device=torch.device("cuda"))
# compare full sequence logprobs with prefill+decode logprobs, tolerance like mcore tolerance
for i in range(num_prompts):
torch.testing.assert_close(torch.tensor(full_sequence_logprobs[i]),
torch.tensor(logprobs_batching[i]),
# compare logprobs with mcore logprobs; check that the max error is less than 0.3
mcore_atol = 0.3
torch.testing.assert_close(torch.tensor(
prefill_logprobs_no_batching[0]),
prefill_logprobs_ref_mcore,
atol=mcore_atol,
rtol=0.0)
kv_cache_manager.shutdown()
for i in range(num_prompts):
# compare prompt logprobs with initial implementation
torch.testing.assert_close(
prefill_logprobs_no_batching[i],
prefill_logprobs_ref_initial_no_batching[i],
atol=initial_impl_atol,
rtol=0.0)
torch.testing.assert_close(
prefill_logprobs_batching[i],
prefill_logprobs_ref_initial_with_batching[i],
atol=initial_impl_atol,
rtol=0.0)
# clear memory before next test
del nemotron_h
torch.cuda.empty_cache()
# compare expected completion
assert completions_batching[i] == expected_completions[i]
assert completions_no_batching[i] == expected_completions[i]
# compare decode logprobs with initial implementation
torch.testing.assert_close(
decode_logprobs_no_batching[i],
decode_logprobs_ref_initial_no_batching[i],
atol=initial_impl_atol,
rtol=0.0)
torch.testing.assert_close(
decode_logprobs_batching[i],
decode_logprobs_ref_initial_with_batching[i],
atol=initial_impl_atol,
rtol=0.0)
# TODO: once the LLM API supports context and generation logits, use it in the test above and remove this one
@skip_gpu_memory_less_than(
(2 * 8 + 1) * 2**30) # 8B, bf16, plus 1 GB for good measure
def test_nemotron_h_llm_api():
model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K"
text_prompts = [
"The future of AI is",
"The president of the United States is",
]
num_prompts = len(text_prompts)
# compare logprobs with and without batching; tolerance based on the difference observed in the initial implementation
torch.testing.assert_close(prefill_logprobs_batching[i],
prefill_logprobs_no_batching[i],
atol=batching_atol,
rtol=0.0)
torch.testing.assert_close(decode_logprobs_batching[i],
decode_logprobs_no_batching[i],
atol=batching_atol,
rtol=0.0)
nemotron_h = LLM(
model=model_dir,
use_cuda_graph=False,
max_batch_size=num_prompts,
kv_cache_config=KvCacheConfig(enable_block_reuse=False),
)
# now let's test that decodes match prefill logprobs
text_prompts_with_completions = [
f"{text_prompts[i]}{completions_batching[i]}"
for i in range(num_prompts)
]
expected_completions = [
" bright, with endless possibilities for innovation and growth",
" the head of state and head of government of",
]
sampling_params.max_tokens = 1
full_sequence_results = nemotron_h.generate(
text_prompts_with_completions, sampling_params)
full_sequence_logprobs = [
extract_prefill_logprobs(result).cpu()
for result in full_sequence_results
]
sampling_params = SamplingParams(max_tokens=9,
temperature=0.0,
add_special_tokens=False)
# compare full sequence logprobs with prefill+decode logprobs, tolerance like mcore tolerance
for i in range(num_prompts):
prefill_decode_logprobs = torch.cat(
[prefill_logprobs_batching[i], decode_logprobs_batching[i]])
torch.testing.assert_close(full_sequence_logprobs[i],
prefill_decode_logprobs,
atol=mcore_atol,
rtol=0.0)
try:
results = nemotron_h.generate(text_prompts, sampling_params)
for result, expected_completion in zip(results, expected_completions):
assert result.outputs[0].text == expected_completion
finally:
nemotron_h.shutdown()
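# Condensed sketch of the LLM API flow the updated test relies on; constructor
# and SamplingParams arguments are taken from the diff above, while the model
# path is a placeholder:
from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.sampling_params import SamplingParams

llm = LLM(model="/path/to/Nemotron-H-8B-Base-8K",  # placeholder path
          max_batch_size=2,
          use_cuda_graph=False,
          kv_cache_config=KvCacheConfig(enable_block_reuse=False),
          enable_trtllm_sampler=True)
sampling_params = SamplingParams(max_tokens=9,
                                 temperature=0.0,
                                 add_special_tokens=False,
                                 return_context_logits=True,
                                 return_generation_logits=True)
try:
    results = llm.generate(["The future of AI is"], sampling_params)
    # result.context_logits covers the prompt tokens; outputs[0].generation_logits
    # and outputs[0].text cover the newly generated tokens.
    print(results[0].outputs[0].text)
finally:
    llm.shutdown()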