From 7b09cd904d343de77f42301a3e7871758f3b77a2 Mon Sep 17 00:00:00 2001 From: tomeras91 <57313761+tomeras91@users.noreply.github.com> Date: Tue, 20 May 2025 12:55:25 +0300 Subject: [PATCH] [TRTLLM-5085][fix] Nemotron H correctness test (#4444) * Replace sanity test for nemotron h with a correctness test * Add prefill+decode reference logprobs from initial implementation + batched forward test * Add testing that decode matches prefill - compare decode vs all prefilling the decoded tokens --- .../modeling/test_modeling_nemotron_h.py | 439 +++++++++++++----- 1 file changed, 335 insertions(+), 104 deletions(-) diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py index d7c689fccf..b44ceb219a 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py @@ -1,5 +1,4 @@ import unittest -from copy import deepcopy import torch @@ -13,80 +12,190 @@ from tensorrt_llm._torch.models.modeling_nemotron_h import (NemotronHConfig, NemotronHForCausalLM ) # isort: on +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from utils.llm_data import llm_models_root +from utils.util import skip_gpu_memory_less_than + +from tensorrt_llm._torch.pyexecutor.model_engine import load_weights from tensorrt_llm._torch.pyexecutor.resource_manager import \ MambaHybridCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping -NEMOTRON_H_CONFIG = { - "architectures": ["NemotronHForCausalLM"], - "attention_bias": False, - "attention_dropout": 0.0, - "attention_head_dim": 128, - "bos_token_id": 1, - "chunk_size": 256, - "conv_kernel": 4, - "eos_token_id": 2, - "expand": 2, - "hidden_dropout": 0.0, - "hidden_size": 4096, - "hybrid_override_pattern": - "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-", - "initializer_range": 0.02, - "intermediate_size": 21504, - "layer_norm_epsilon": 1e-05, - "mamba_head_dim": 64, - "mamba_hidden_act": "silu", - "mamba_num_heads": 128, - "mamba_proj_bias": False, - "max_position_embeddings": 8192, - "mlp_bias": False, - "mlp_hidden_act": "relu2", - "model_type": "nemotron_h", - "n_groups": 8, - "num_attention_heads": 32, - "num_hidden_layers": 52, - "num_key_value_heads": 8, - "num_logits_to_keep": 1, - "pad_token_id": 0, - "rescale_prenorm_residual": True, - "residual_in_fp32": False, - "rms_norm_eps": 1e-05, - "sliding_window": None, - "ssm_state_size": 128, - "tie_word_embeddings": False, - "torch_dtype": "bfloat16", - "use_bias": False, - "use_cache": True, - "use_conv_bias": True, - "use_mamba_kernels": True, - "vocab_size": 131072 -} + +def get_logprobs(token_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: + raw_probs = torch.softmax(logits, dim=-1) + index = token_ids.unsqueeze(1).cuda() + token_probs = torch.gather(raw_probs, dim=1, index=index).squeeze(-1) + return torch.log(token_probs) + + +def _generate( + model: NemotronHForCausalLM, tokenizer: PreTrainedTokenizerBase, + cache: MambaHybridCacheManager, text_prompts: list[str], + tokens_to_generate: int, device: torch.device +) -> tuple[list[int], list[list[int]], list[list[float]]]: + num_seqs = len(text_prompts) + all_token_ids = [ + tokenizer.encode(prompt, add_special_tokens=False) + for prompt in text_prompts + ] + input_ids = torch.cat([ + torch.tensor(token_ids, dtype=torch.int64, device=device) + for token_ids in all_token_ids + ], + dim=0) + request_ids = list(range(1, num_seqs + 1)) + 
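+    # Register the requests with the MambaHybridCacheManager before prefill:
+    # add_dummy_requests reserves KV-cache blocks for each prompt and
+    # prepare_mamba_cache_blocks sets up the per-request Mamba state.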
prompt_lens = [len(token_ids) for token_ids in all_token_ids] + + requests = cache.add_dummy_requests(request_ids, prompt_lens) + cache.prepare_mamba_cache_blocks(request_ids) + + metadata_cls = get_attention_backend( + model.model_config.attn_backend).Metadata + attn_metadata = metadata_cls( + seq_lens=torch.tensor(prompt_lens, dtype=torch.int), + num_contexts=num_seqs, + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=[0] * num_seqs, + ), + max_num_requests=num_seqs, + max_num_tokens=8192, + kv_cache_manager=cache, + request_ids=request_ids, + prompt_lens=prompt_lens, + ) + + # prefill + position_ids = [torch.arange(0, prompt_len) for prompt_len in prompt_lens] + position_ids = torch.cat(position_ids).unsqueeze(0).cuda() + with torch.inference_mode(): + attn_metadata.prepare() + logits = model.forward(input_ids=input_ids, + position_ids=position_ids, + attn_metadata=attn_metadata, + return_context_logits=True) + + # compute logprobs from logits + all_logits = logits.split(prompt_lens, dim=0) + all_logprobs = [ + get_logprobs( + torch.tensor(token_ids[1:], dtype=torch.int64, device=device), + this_logits[:-1]).tolist() + for token_ids, this_logits in zip(all_token_ids, all_logits) + ] + + if tokens_to_generate > 0: + # sample token greedily + sampled_tokens = torch.cat([ + torch.argmax(this_logits[-1]).unsqueeze(0) + for this_logits in all_logits + ], + dim=0) + for i in range(num_seqs): + all_token_ids[i].append(sampled_tokens[i].item()) + all_logprobs[i].append( + get_logprobs(sampled_tokens[i].unsqueeze(0), + all_logits[i][-1:]).item()) + + # one token already generated at prefill + for i in range(tokens_to_generate - 1): + num_cached_tokens_per_seq = [ + prompt_len + i + 1 for prompt_len in prompt_lens + ] + position_ids = torch.tensor([num_cached_tokens_per_seq], + dtype=torch.int64, + device=device) + + attn_metadata = metadata_cls( + seq_lens=torch.tensor([1] * num_seqs, dtype=torch.int), + num_contexts=0, + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=num_cached_tokens_per_seq, + ), + max_num_requests=num_seqs, + max_num_tokens=8192, + kv_cache_manager=cache, + request_ids=request_ids, + prompt_lens=prompt_lens, + ) + + with torch.inference_mode(): + attn_metadata.prepare() + logits = model.forward(input_ids=sampled_tokens, + position_ids=position_ids, + attn_metadata=attn_metadata) + + # sample token greedily + sampled_tokens = torch.argmax(logits, dim=-1, keepdim=False) + for i in range(num_seqs): + all_token_ids[i].append(sampled_tokens[i].item()) + all_logprobs[i].append( + get_logprobs(sampled_tokens[i].unsqueeze(0), + logits[i].unsqueeze(0)).item()) + + for req in requests: + cache.free_resources(req) + + return prompt_lens, all_token_ids, all_logprobs + + +def generate( + model: NemotronHForCausalLM, + tokenizer: PreTrainedTokenizerBase, + cache: MambaHybridCacheManager, + text_prompts: list[str], + tokens_to_generate: int, + device: torch.device, + one_by_one: bool = False +) -> tuple[list[int], list[list[int]], list[list[float]]]: + """ + Generate `tokens_to_generate` tokens from the given prompts using the given model and cache. + Return the prompt_lens along with the prefill+generated tokens and their logprobs, minus the first token in the prompt. 
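+
+    If `one_by_one` is True, each prompt is run through `_generate` separately
+    (batch size 1); otherwise all prompts are prefilled and decoded together in
+    one batch. With `tokens_to_generate == 0` only the prefill pass runs, so the
+    returned logprobs cover just the prompt tokens.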
+ """ + if one_by_one: + num_prompts = len(text_prompts) + prompt_lens, tokens, logprobs = [None] * num_prompts, [ + None + ] * num_prompts, [None] * num_prompts + for i in range(num_prompts): + p, t, l = _generate(model, tokenizer, cache, [text_prompts[i]], + tokens_to_generate, device) + prompt_lens[i], tokens[i], logprobs[i] = p[0], t[0], l[0] + return prompt_lens, tokens, logprobs + return _generate(model, tokenizer, cache, text_prompts, tokens_to_generate, + device) class TestNemotronH(unittest.TestCase): - def test_nemotron_h_sanity(self): - config_dict = deepcopy(NEMOTRON_H_CONFIG) - nemotron_h_config = NemotronHConfig.from_dict(config_dict) + @skip_gpu_memory_less_than( + (2 * 8 + 1) * 2**30) # 8B, bf16, plus 1 GB for good measure + def test_nemotron_correctness(self): + model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K" + nemotron_h_config = NemotronHConfig.from_pretrained(model_dir) + + tokenizer = AutoTokenizer.from_pretrained(model_dir) dtype = nemotron_h_config.torch_dtype device = torch.device('cuda') + assert dtype == torch.bfloat16 + kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16 model_config = ModelConfig(pretrained_config=nemotron_h_config) nemotron_h = NemotronHForCausalLM(model_config).to(device) - input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500], - dtype=torch.int, - device=device) + mapping = Mapping(world_size=1, tp_size=1, rank=0) + weights = load_weights(model_dir, mapping=mapping) + nemotron_h.load_weights(weights) - context_sequence_lengths = [3, 2, 1] - sequence_lengths = context_sequence_lengths + [1, 1] - past_seen_tokens = [0, 0, 0, 62, 75] - request_ids = list(range(len(sequence_lengths))) - token_nums = (torch.tensor(past_seen_tokens) + - torch.tensor(sequence_lengths)).tolist() - prompt_lens = token_nums[:3] + past_seen_tokens[3:] + text_prompts = [ + "The future of AI is", + "The president of the United States is", + ] + num_prompts = len(text_prompts) num_blocks = 100 tokens_per_block = 128 @@ -101,14 +210,8 @@ class TestNemotronH(unittest.TestCase): ] num_kv_heads = nemotron_h.config.num_key_value_heads max_seq_len = num_blocks * tokens_per_block - batch_size = len(context_sequence_lengths) + 2 + max_batch_size = num_prompts - if dtype == torch.bfloat16: - kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16 - else: - raise ValueError("Invalid dtype") - - mapping = Mapping(world_size=1, tp_size=1, rank=0) kv_cache_config = KvCacheConfig(max_tokens=num_blocks * tokens_per_block, enable_block_reuse=False) @@ -132,53 +235,181 @@ class TestNemotronH(unittest.TestCase): head_dim=head_dim, tokens_per_block=tokens_per_block, max_seq_len=max_seq_len, - max_batch_size=batch_size, + max_batch_size=max_batch_size, mapping=mapping, dtype=kv_cache_dtype, ) - kv_cache_manager.add_dummy_requests(request_ids, token_nums) - kv_cache_manager.prepare_mamba_cache_blocks(request_ids) - metadata_cls = get_attention_backend(model_config.attn_backend).Metadata - attn_metadata = metadata_cls( - seq_lens=torch.tensor(sequence_lengths, dtype=torch.int), - num_contexts=len(context_sequence_lengths), - kv_cache_params=KVCacheParams( - use_cache=True, - num_cached_tokens_per_seq=past_seen_tokens, - ), - kv_cache_manager=kv_cache_manager, - request_ids=request_ids, - prompt_lens=prompt_lens, - max_num_requests=len(context_sequence_lengths) + 2, - max_num_tokens=8192, - ) + prompt_lens, tokens_no_batching, logprobs_no_batching = generate( + model=nemotron_h, + tokenizer=tokenizer, + cache=kv_cache_manager, + text_prompts=text_prompts, + 
tokens_to_generate=9, + device=torch.device("cuda"), + one_by_one=True) + completions_no_batching = [ + tokenizer.decode(tokens_no_batching[i][prompt_lens[i]:]) + for i in range(num_prompts) + ] - position_ids = [] - for i, tokens in enumerate(past_seen_tokens): - seq_len = context_sequence_lengths[i] if i < len( - context_sequence_lengths) else 1 - position_id = torch.arange(tokens, - tokens + seq_len, - device=input_ids.device) - position_ids.append(position_id) + _, tokens_batching, logprobs_batching = generate( + model=nemotron_h, + tokenizer=tokenizer, + cache=kv_cache_manager, + text_prompts=text_prompts, + tokens_to_generate=9, + device=torch.device("cuda")) + completions_batching = [ + tokenizer.decode(tokens_batching[i][prompt_lens[i]:]) + for i in range(num_prompts) + ] - position_ids = torch.cat(position_ids).unsqueeze(0) + # reference logprobs for first prompt from mcore for prompt minus first token + # TODO(oargov): generate a reference on-the-fly once we have confidence in the HF impl + prefill_logprobs_ref_mcore = torch.tensor([ + -7.415980815887451, -0.36192911863327026, -2.8658294677734375, + -2.316344738006592 + ]) - with torch.inference_mode(): - attn_metadata.prepare() - logits = nemotron_h.forward(input_ids=input_ids, - position_ids=position_ids, - attn_metadata=attn_metadata) + # compare logprobs with mcore logprobs, check that the max error is less than 0.3 + mcore_atol = 0.3 + torch.testing.assert_close(torch.tensor( + logprobs_no_batching[0][:prompt_lens[0] - 1]), + prefill_logprobs_ref_mcore, + atol=mcore_atol, + rtol=0.0) - self.assertEqual(len(past_seen_tokens), logits.shape[0]) + # reference logprobs from initial implementation (commit 5ce1102a02bd2938c0c8334138371f081f55fcc1 on single RTX 6000) + initial_impl_atol = 0.2 + batching_atol = 0.2 - with torch.inference_mode(): - attn_metadata.prepare() - logits = nemotron_h.forward(input_ids=input_ids, - position_ids=position_ids, - attn_metadata=attn_metadata, - return_context_logits=True) - self.assertEqual(input_ids.shape, logits.shape[:-1]) + prefill_logprobs_ref_initial_no_batching = [ + torch.tensor([ + -7.4359540939331055, + -0.37661877274513245, + -2.8925108909606934, + -2.268364906311035, + ]), + torch.tensor([ + -8.759482383728027, + -1.656238079071045, + -0.5448741912841797, + -1.7702054977416992, + -0.05832016468048096, + -1.460732102394104, + ]) + ] + prefill_logprobs_ref_initial_with_batching = [ + torch.tensor([ + -7.401950836181641, -0.38696032762527466, -2.8725428581237793, + -2.2654521465301514 + ]), + torch.tensor([ + -8.73007583618164, -1.6853574514389038, -0.5468529462814331, + -1.7846013307571411, -0.053610533475875854, -1.4385275840759277 + ]) + ] + + decode_logprobs_ref_initial_no_batching = [ + torch.tensor([ + -2.2722280025482178, -0.5235245823860168, -0.8821321725845337, + -1.9436249732971191, -0.07366813719272614, -0.4224405586719513, + -0.3872227966785431, -0.06121065467596054, -1.0475994348526 + ]), + torch.tensor([ + -1.329713225364685, -1.6879069805145264, -0.040034178644418716, + -0.4808207154273987, -0.3581068515777588, -0.2784178853034973, + -0.005814795847982168, -0.0563097707927227, -0.05941024422645569 + ]) + ] + decode_logprobs_ref_initial_with_batching = [ + torch.tensor([ + -2.2877156734466553, -0.507795512676239, -0.8313305377960205, + -1.940523386001587, -0.07369701564311981, -0.4190545976161957, + -0.4250463843345642, -0.061063338071107864, -1.046282410621643 + ]), + torch.tensor([ + -1.3567769527435303, -1.7291667461395264, -0.04527968540787697, + -0.4836069345474243, 
-0.3971801698207855, -0.2481495887041092, + -0.005787517875432968, -0.056093256920576096, + -0.058267030864953995 + ]) + ] + + expected_completions = [ + " bright, with endless possibilities for innovation and growth", + " the head of state and head of government of", + ] + + for i in range(num_prompts): + prefill_logprobs_no_batching = torch.tensor( + logprobs_no_batching[i][:prompt_lens[i] - 1]) + decode_logprobs_no_batching = torch.tensor( + logprobs_no_batching[i][prompt_lens[i] - 1:]) + + prefill_logprobs_batching = torch.tensor( + logprobs_batching[i][:prompt_lens[i] - 1]) + decode_logprobs_batching = torch.tensor( + logprobs_batching[i][prompt_lens[i] - 1:]) + + # compare prompt logprobs with initial implementation + torch.testing.assert_close( + prefill_logprobs_no_batching, + prefill_logprobs_ref_initial_no_batching[i], + atol=initial_impl_atol, + rtol=0.0) + torch.testing.assert_close( + prefill_logprobs_batching, + prefill_logprobs_ref_initial_with_batching[i], + atol=initial_impl_atol, + rtol=0.0) + + # compare expected completion + self.assertEqual(completions_batching[i], expected_completions[i]) + self.assertEqual(completions_no_batching[i], + expected_completions[i]) + + # compare decode logprobs with initial implementation + torch.testing.assert_close( + decode_logprobs_no_batching, + decode_logprobs_ref_initial_no_batching[i], + atol=initial_impl_atol, + rtol=0.0) + torch.testing.assert_close( + decode_logprobs_batching, + decode_logprobs_ref_initial_with_batching[i], + atol=initial_impl_atol, + rtol=0.0) + + # compare logprobs with and without batching, tolerace by diff in initial implementation + torch.testing.assert_close(prefill_logprobs_batching, + prefill_logprobs_no_batching, + atol=batching_atol, + rtol=0.0) + torch.testing.assert_close(decode_logprobs_batching, + decode_logprobs_no_batching, + atol=batching_atol, + rtol=0.0) + + # now let's test that decodes match prefill logprobs + text_prompts_with_completions = [ + f"{text_prompts[i]}{completions_batching[i]}" + for i in range(num_prompts) + ] + _, _, full_sequence_logprobs = generate( + model=nemotron_h, + tokenizer=tokenizer, + cache=kv_cache_manager, + text_prompts=text_prompts_with_completions, + tokens_to_generate=0, + device=torch.device("cuda")) + + # compare full sequence logprobs with prefill+decode logprobs, tolerance like mcore tolerance + for i in range(num_prompts): + torch.testing.assert_close(torch.tensor(full_sequence_logprobs[i]), + torch.tensor(logprobs_batching[i]), + atol=mcore_atol, + rtol=0.0) kv_cache_manager.shutdown()
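
Note on the logprob indexing used above: get_logprobs(token_ids[1:], logits[:-1]) scores each prompt token with the logit at the preceding position, so a prompt of L tokens yields L - 1 prefill logprobs ("minus the first token in the prompt"), which is why each reference tensor is one element shorter than its prompt. Below is a minimal standalone sketch of that alignment, using hypothetical toy tensors that are not part of the patch (log_softmax is mathematically equivalent to the softmax-then-log used in get_logprobs):

    import torch

    # Toy example with hypothetical values: the logit at position i predicts the
    # token at position i + 1, so a length-4 prompt yields 3 prefill logprobs.
    token_ids = torch.tensor([5, 17, 3, 42])           # toy prompt, L = 4
    logits = torch.randn(4, 10)                        # [seq_len, vocab_size]
    logprobs = torch.log_softmax(logits[:-1], dim=-1)  # positions 0..2 score tokens 1..3
    prefill_logprobs = logprobs.gather(1, token_ids[1:].unsqueeze(1)).squeeze(1)
    assert prefill_logprobs.shape == (3,)               # L - 1 values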