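"""Unit tests for the TensorRT-LLM PyTorch implementation of NVSmall
(DeciLM-style / Nemotron-NAS models with heterogeneous layers).

The tests build a small, randomly initialized model from a reduced config,
run context and generation steps through NVSmallForCausalLM with a KV cache,
and compare the resulting logits against the HuggingFace reference model.
"""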
import os
import sys
import unittest
from copy import deepcopy
from dataclasses import dataclass
from typing import Any

import torch
from parameterized import parameterized
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING

import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_nvsmall import NVSmallForCausalLM
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root
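# A reduced DeciLM-style ("NVSmall") configuration with 14 heterogeneous layers.
# Each "block_configs" entry describes one layer: its attention or FFN sub-block
# may be a regular block, be replaced with a linear layer, or be a no-op.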
NVSMALL_MINI_CONFIG = {
    "architectures": ["DeciLMForCausalLM"],
    "attention_bias": False,
    "block_configs": [{
        "attention": {"n_heads_in_group": 8, "no_op": False, "replace_with_linear": False},
        "ffn": {"ffn_mult": 2.0, "no_op": False, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": 16, "no_op": False, "replace_with_linear": False},
        "ffn": {"ffn_mult": 2.0, "no_op": False, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": None, "no_op": False, "replace_with_linear": True},
        "ffn": {"ffn_mult": 2.0, "no_op": False, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": None, "no_op": True, "replace_with_linear": False},
        "ffn": {"ffn_mult": 2.0, "no_op": False, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": 8, "no_op": False, "replace_with_linear": False},
        "ffn": {"ffn_mult": None, "no_op": False, "replace_with_linear": True},
    }, {
        "attention": {"n_heads_in_group": 4, "no_op": False, "replace_with_linear": False},
        "ffn": {"ffn_mult": None, "no_op": True, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": 8, "no_op": False, "replace_with_linear": False},
        "ffn": {"ffn_mult": 2.0, "no_op": False, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": 8, "no_op": False, "replace_with_linear": False},
        "ffn": {"ffn_mult": 2.0, "no_op": False, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": 16, "no_op": False, "replace_with_linear": False},
        "ffn": {"ffn_mult": 2.0, "no_op": False, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": None, "no_op": False, "replace_with_linear": True},
        "ffn": {"ffn_mult": 2.0, "no_op": False, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": None, "no_op": True, "replace_with_linear": False},
        "ffn": {"ffn_mult": 2.0, "no_op": False, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": 8, "no_op": False, "replace_with_linear": False},
        "ffn": {"ffn_mult": None, "no_op": False, "replace_with_linear": True},
    }, {
        "attention": {"n_heads_in_group": 4, "no_op": False, "replace_with_linear": False},
        "ffn": {"ffn_mult": None, "no_op": True, "replace_with_linear": False},
    }, {
        "attention": {"n_heads_in_group": 8, "no_op": False, "replace_with_linear": False},
        "ffn": {"ffn_mult": 2.0, "no_op": False, "replace_with_linear": False},
    }],
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": None,
    "max_position_embeddings": 2048,
    "model_type": "deci",
    "num_attention_heads": 32,
    "num_hidden_layers": 14,
    "num_key_value_heads": None,
    "rms_norm_eps": 1e-06,
    "rope_scaling": None,
    "rope_theta": 10000.0,
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "use_cache": True,
    "vocab_size": 32128
}
@dataclass(repr=False)
class Scenario:
    backend: str

    def __repr__(self) -> str:
        return f"backend:{self.backend.lower()}"
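# Shrink the number of layers (and the matching block_configs entries) in
# proportion to the available GPU memory so the tests can run on small GPUs.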
def reduce_nvsmall_config(mem_for_full_model: int, config_dict: dict[str, Any]):
    _, total_mem = torch.cuda.mem_get_info()
    # scale model down if gpu memory is low
    if total_mem < mem_for_full_model:
        model_fraction = total_mem / mem_for_full_model
        num_layers = int(config_dict["num_hidden_layers"] * model_fraction)
        num_layers = min(num_layers, 32)
        config_dict["num_hidden_layers"] = num_layers
        config_dict["block_configs"] = config_dict["block_configs"][:num_layers]
class TestNVSmall(unittest.TestCase):

    def test_nvsmall_sanity(self):
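        """Sanity check: run a mixed context/generation batch through
        NVSmallForCausalLM with a KV cache and verify the logits shapes."""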
        config_dict = deepcopy(NVSMALL_MINI_CONFIG)
        # 8B * sizeof(float16) plus some extra for activations
        mem_for_full_model = (2 + 1) * 8 * 2**(30)
        reduce_nvsmall_config(mem_for_full_model, config_dict)
        if config_dict["num_hidden_layers"] <= 0:
            self.skipTest("Insufficient memory for a single NVSmall layer")
        nvsmall_config = AutoConfig.from_pretrained(
            "nvidia/Llama-3_1-Nemotron-51B-Instruct", trust_remote_code=True)
        nvsmall_config = nvsmall_config.from_dict(config_dict)

        dtype = nvsmall_config.torch_dtype
        device = torch.device('cuda')

        model_config = ModelConfig(pretrained_config=nvsmall_config)
        nvsmall = NVSmallForCausalLM(model_config).to(dtype).to(device)

        input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                 dtype=torch.int,
                                 device=device)
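        # Build a KV cache large enough for the dummy requests registered below.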
        num_blocks = 1000
        tokens_per_block = 128

        if dtype == torch.half:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype")

        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)

        num_layers = nvsmall.config.num_hidden_layers
        num_kv_heads = nvsmall.config.num_key_value_heads
        num_heads = nvsmall.config.num_attention_heads
        head_dim = nvsmall.config.hidden_size // num_heads
        max_seq_len = num_blocks * tokens_per_block
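        # Mixed batch: three context requests (lengths 3, 2, 1) followed by two
        # generation requests that already have 62 and 75 cached tokens.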
        context_sequence_lengths = [3, 2, 1]
        sequence_lengths = context_sequence_lengths + [1, 1]
        batch_size = len(sequence_lengths)
        past_seen_tokens = [0, 0, 0, 62, 75]
        request_ids = list(range(len(sequence_lengths)))
        token_nums = (torch.tensor(past_seen_tokens) +
                      torch.tensor(sequence_lengths)).tolist()
        prompt_lens = token_nums[:3] + past_seen_tokens[3:]

        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers, num_heads, num_kv_heads, head_dim, tokens_per_block,
            max_seq_len, batch_size, mapping, kv_cache_dtype)
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)
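        # The attention metadata describes, per request, how many tokens are new
        # (seq_lens) and how many are already present in the KV cache.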
        metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
            num_contexts=len(context_sequence_lengths),
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=len(context_sequence_lengths) + 2,
            max_num_tokens=8192,
        )
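        # Position ids continue from the number of tokens each request has already seen.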
        position_ids = []
        for i, tokens in enumerate(past_seen_tokens):
            seq_len = context_sequence_lengths[i] if i < len(
                context_sequence_lengths) else 1
            position_id = torch.arange(tokens,
                                       tokens + seq_len,
                                       device=input_ids.device)
            position_ids.append(position_id)

        position_ids = torch.cat(position_ids).unsqueeze(0)
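        # Without return_context_logits, the model returns one row of logits per request.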
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = nvsmall.forward(input_ids=input_ids,
                                     position_ids=position_ids,
                                     attn_metadata=attn_metadata)

        self.assertEqual(len(past_seen_tokens), logits.shape[0])
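        # With return_context_logits=True, one row of logits is returned per input token.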
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = nvsmall.forward(input_ids=input_ids,
                                     position_ids=position_ids,
                                     attn_metadata=attn_metadata,
                                     return_context_logits=True)
        self.assertEqual(input_ids.shape, logits.shape[:-1])

        kv_cache_manager.shutdown()
    @parameterized.expand([
        Scenario(backend="VANILLA"),
        Scenario(backend="FLASHINFER"),
        Scenario(backend="TRTLLM"),
    ], lambda testcase_func, param_num, param:
                           f"{testcase_func.__name__}[{param.args[0]}]")
    @torch.no_grad()
    def test_nvsmall_allclose_to_hf(self, scenario: Scenario) -> None:
        """
        Compare the TensorRT-LLM NVSmall outputs against the HuggingFace
        reference implementation on a context step and a generation step.
        """
        backend = scenario.backend
        metadata_cls = get_attention_backend(backend).Metadata
        torch.random.manual_seed(0)
        config_dict = deepcopy(NVSMALL_MINI_CONFIG)
        # 8B * sizeof(float16) plus some extra for activations
        # times 2, since we'll need 2 of these
        mem_for_full_model = (2 + 1) * 8 * 2**(30) * 4
        reduce_nvsmall_config(mem_for_full_model, config_dict)
        if config_dict["num_hidden_layers"] <= 0:
            self.skipTest("Insufficient memory for a single NVSmall layer")
        nvsmall_config = AutoConfig.from_pretrained(
            "nvidia/Llama-3_1-Nemotron-51B-Instruct", trust_remote_code=True)
        nvsmall_config = nvsmall_config.from_dict(config_dict)
        dtype = nvsmall_config.torch_dtype
        device = torch.device('cuda')
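        # Instantiate the remote-code model class on the meta device (no weights),
        # then rebuild it from the reduced config with random weights.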
        hf_nvsmall = AutoModelForCausalLM.from_pretrained(
            llm_models_root() / "nemotron-nas/Llama-3_1-Nemotron-51B-Instruct",
            trust_remote_code=True,
            device_map="meta")
        hf_nvsmall = hf_nvsmall.__class__(nvsmall_config).to(dtype).to(
            device).eval()
        # This line populates the "variable" field in the NEED_SETUP_CACHE_CLASSES_MAPPING dict
        hf_nvsmall._prepare_generation_config(None)
        # And this line is the only way to access the only concrete Cache class DeciLMForCausalLM accepts
        VariableCache = NEED_SETUP_CACHE_CLASSES_MAPPING["variable"]
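        # Build the TensorRT-LLM model with the requested attention backend and
        # copy the HF weights into it so both models share the same parameters.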
        model_config = ModelConfig(pretrained_config=nvsmall_config,
                                   attn_backend=backend)
        nvsmall = NVSmallForCausalLM(model_config).to(dtype).to(device)
        nvsmall.load_weights(hf_nvsmall.state_dict())
        num_blocks = 1
        tokens_per_block = 128

        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)

        num_layers = nvsmall.config.num_hidden_layers
        num_kv_heads = nvsmall.config.num_key_value_heads
        num_heads = nvsmall.config.num_attention_heads
        head_dim = nvsmall.config.hidden_size // num_heads
        max_seq_len = num_blocks * tokens_per_block
        batch_size = 1

        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        if dtype == torch.half:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype")

        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers, num_heads, num_kv_heads, head_dim, tokens_per_block,
            max_seq_len, batch_size, mapping, kv_cache_dtype)
        # context phase: prefill the whole prompt in a single forward pass
        input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                 dtype=torch.int,
                                 device=device)

        num_cached_tokens_per_seq = [0]
        request_ids = [1]
        token_nums = [input_ids.size(-1)]
        prompt_lens = [input_ids.size(-1)]
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
            num_contexts=1,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=1,
            max_num_tokens=8192,
        )

        position_ids = [torch.arange(0, input_ids.size(-1))]
        position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
        # And, lastly, this is the simplest way of creating a Cache that `hf_nvsmall` will accept
        past_key_values = VariableCache(config=nvsmall_config,
                                        dtype=dtype,
                                        batch_size=1)
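        # Run the context step through both implementations and compare the
        # logits of the last prompt token.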
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = nvsmall.forward(input_ids=input_ids,
                                     position_ids=position_ids,
                                     attn_metadata=attn_metadata)
            ref = hf_nvsmall.forward(input_ids=input_ids.unsqueeze(0),
                                     position_ids=position_ids,
                                     past_key_values=past_key_values,
                                     use_cache=True)

        torch.testing.assert_close(logits,
                                   ref.logits[:, -1].float(),
                                   atol=0.1,
                                   rtol=0.1)
        # generation phase: decode a single new token on top of the cached prompt
        gen_input_ids = torch.tensor([600], dtype=torch.int, device=device)

        num_cached_tokens_per_seq = [input_ids.size(-1)]
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
            num_contexts=0,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=1,
            max_num_tokens=8192,
        )

        gen_position_ids = [
            torch.arange(input_ids.size(-1),
                         input_ids.size(-1) + gen_input_ids.size(-1))
        ]
        gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = nvsmall.forward(input_ids=gen_input_ids,
                                     position_ids=gen_position_ids,
                                     attn_metadata=attn_metadata)
            ref = hf_nvsmall.forward(input_ids=gen_input_ids.unsqueeze(0),
                                     position_ids=gen_position_ids,
                                     past_key_values=ref.past_key_values,
                                     use_cache=True)

        torch.testing.assert_close(logits,
                                   ref.logits[:, -1].float(),
                                   atol=0.1,
                                   rtol=0.1)

        kv_cache_manager.shutdown()