[fix] speedup modeling unittests (#5579)
Signed-off-by: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com>
Parent: 4fef14da56
Commit: 1db63c2546
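The speedup comes from constructing each test model directly on the GPU in its target dtype, instead of initializing it on the CPU in float32 and then casting and copying it with .to(dtype).to(device). A minimal sketch of the pattern, using a plain nn.Linear as a stand-in for the TRT-LLM and Hugging Face models touched below (assumes a CUDA device and PyTorch >= 2.0, where torch.device works as a context manager):

import torch
import torch.nn as nn

device = torch.device("cuda")
dtype = torch.bfloat16

# Old pattern: parameters are allocated and initialized on the CPU in float32,
# then cast and copied to the GPU -- two extra full passes over the weights.
slow = nn.Linear(1024, 1024).to(dtype).to(device)

# New pattern: the device context manager plus a temporary default dtype make
# every parameter land on the GPU in bfloat16 at construction time.
prev = torch.get_default_dtype()
torch.set_default_dtype(dtype)
with device:
    fast = nn.Linear(1024, 1024)
torch.set_default_dtype(prev)

assert fast.weight.device.type == "cuda" and fast.weight.dtype == dtype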
@@ -7,7 +7,7 @@ import torch
 from parameterized import parameterized
 from transformers import LlamaConfig
 from transformers import LlamaForCausalLM as HFLlamaForCausalLM
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
 
 import tensorrt_llm
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -97,9 +97,10 @@ class TestLlama(unittest.TestCase):
         dtype = llama_config.torch_dtype
         device = torch.device('cuda')
 
-        model_config = ModelConfig(pretrained_config=llama_config,
-                                   quant_config=quant_config)
-        llama = LlamaForCausalLM(model_config).to(device)
+        with torch.device(device), default_dtype(dtype):
+            model_config = ModelConfig(pretrained_config=llama_config,
+                                       quant_config=quant_config)
+            llama = LlamaForCausalLM(model_config).to(device)
 
         input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                  dtype=torch.int,
@@ -217,12 +218,14 @@ class TestLlama(unittest.TestCase):
         dtype = llama_config.torch_dtype
         device = torch.device('cuda')
 
-        hf_llama = HFLlamaForCausalLM(llama_config).to(dtype).to(device).eval()
+        with torch.device(device), default_dtype(dtype):
+            hf_llama = HFLlamaForCausalLM(llama_config).eval()
 
-        model_config = ModelConfig(pretrained_config=llama_config,
-                                   attn_backend=backend)
-        llama = LlamaForCausalLM(model_config).to(dtype).to(device)
-        llama.load_weights(hf_llama.state_dict())
+            model_config = ModelConfig(pretrained_config=llama_config,
+                                       attn_backend=backend)
+
+            llama = LlamaForCausalLM(model_config).to(dtype).to(device)
+            llama.load_weights(hf_llama.state_dict())
 
         num_blocks = 1
         tokens_per_block = 128
@@ -8,7 +8,7 @@ from transformers import Llama4Config
 from transformers import \
     Llama4ForConditionalGeneration as HFLlama4ForConditionalGeneration
 from transformers.cache_utils import DynamicCache
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
 
 import tensorrt_llm
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -152,11 +152,12 @@ class TestLlama4MinLatency(unittest.TestCase):
         dtype = llama_config.torch_dtype
         device = torch.device('cuda')
 
-        model_config = ModelConfig(pretrained_config=llama_config,
-                                   quant_config=quant_config)
-        model_config.pytorch_backend_config = PyTorchConfig(
-            enable_min_latency=enable_min_latency)
-        llama = Llama4ForConditionalGeneration(model_config).to(device)
+        with torch.device(device), default_dtype(dtype):
+            model_config = ModelConfig(pretrained_config=llama_config,
+                                       quant_config=quant_config)
+            model_config.pytorch_backend_config = PyTorchConfig(
+                enable_min_latency=enable_min_latency)
+            llama = Llama4ForConditionalGeneration(model_config)
 
         input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                  dtype=torch.int,
@@ -275,16 +276,15 @@ class TestLlama4MinLatency(unittest.TestCase):
         dtype = llama_config.torch_dtype
         device = torch.device('cuda')
 
-        hf_llama = HFLlama4ForConditionalGeneration(llama_config).to(dtype).to(
-            device).eval()
+        with torch.device(device), default_dtype(dtype):
+            hf_llama = HFLlama4ForConditionalGeneration(llama_config).eval()
 
-        model_config = ModelConfig(pretrained_config=llama_config,
-                                   attn_backend=attention_backend)
-        model_config.pytorch_backend_config = PyTorchConfig(
-            enable_min_latency=enable_min_latency)
-        llama = Llama4ForConditionalGeneration(model_config).to(dtype).to(
-            device)
-        llama.load_weights(hf_llama.state_dict())
+            model_config = ModelConfig(pretrained_config=llama_config,
+                                       attn_backend=attention_backend)
+            model_config.pytorch_backend_config = PyTorchConfig(
+                enable_min_latency=enable_min_latency)
+            llama = Llama4ForConditionalGeneration(model_config)
+            llama.load_weights(hf_llama.state_dict())
 
         num_blocks = 1
         tokens_per_block = 128
@@ -6,7 +6,7 @@ import torch
 from parameterized import parameterized
 from transformers import MixtralConfig
 from transformers import MixtralForCausalLM as HFMixtralForCausalLM
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
 
 import tensorrt_llm
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -82,9 +82,10 @@ class TestMixtral(unittest.TestCase):
         dtype = mixtral_config.torch_dtype
         device = torch.device("cuda")
 
-        model_config = ModelConfig(pretrained_config=mixtral_config,
-                                   quant_config=quant_config)
-        mixtral = MixtralForCausalLM(model_config).to(device)
+        with torch.device(device), default_dtype(dtype):
+            model_config = ModelConfig(pretrained_config=mixtral_config,
+                                       quant_config=quant_config)
+            mixtral = MixtralForCausalLM(model_config)
 
         input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                  dtype=torch.int32,
@@ -199,13 +200,13 @@ class TestMixtral(unittest.TestCase):
         dtype = mixtral_config.torch_dtype
         device = torch.device("cuda")
 
-        hf_mixtral = HFMixtralForCausalLM(mixtral_config).to(dtype).to(
-            device).eval()
+        with torch.device(device), default_dtype(dtype):
+            hf_mixtral = HFMixtralForCausalLM(mixtral_config).eval()
 
-        model_config = ModelConfig(pretrained_config=mixtral_config,
-                                   attn_backend=backend)
-        mixtral = MixtralForCausalLM(model_config).to(device)
-        mixtral.load_weights(hf_mixtral.state_dict())
+            model_config = ModelConfig(pretrained_config=mixtral_config,
+                                       attn_backend=backend)
+            mixtral = MixtralForCausalLM(model_config)
+            mixtral.load_weights(hf_mixtral.state_dict())
 
         num_blocks = 1
         tokens_per_block = 128
||||
@ -7,7 +7,7 @@ import torch
|
||||
from parameterized import parameterized
|
||||
from transformers import NemotronConfig
|
||||
from transformers import NemotronForCausalLM as HFNemotronForCausalLM
|
||||
from utils.util import getSMVersion
|
||||
from utils.util import default_dtype, getSMVersion
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
|
||||
@ -88,10 +88,11 @@ class TestNemotron(unittest.TestCase):
|
||||
dtype = nemotron_config.torch_dtype
|
||||
device = torch.device('cuda')
|
||||
|
||||
model_config = ModelConfig(pretrained_config=nemotron_config,
|
||||
quant_config=quant_config,
|
||||
attn_backend="TRTLLM")
|
||||
nemotron = NemotronForCausalLM(model_config).to(device)
|
||||
with torch.device(device), default_dtype(dtype):
|
||||
model_config = ModelConfig(pretrained_config=nemotron_config,
|
||||
quant_config=quant_config,
|
||||
attn_backend="TRTLLM")
|
||||
nemotron = NemotronForCausalLM(model_config)
|
||||
|
||||
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
|
||||
dtype=torch.int,
|
||||
@ -210,13 +211,13 @@ class TestNemotron(unittest.TestCase):
|
||||
dtype = nemotron_config.torch_dtype
|
||||
device = torch.device('cuda')
|
||||
|
||||
hf_nemotron = HFNemotronForCausalLM(nemotron_config).to(dtype).to(
|
||||
device).eval()
|
||||
with torch.device(device), default_dtype(dtype):
|
||||
hf_nemotron = HFNemotronForCausalLM(nemotron_config).eval()
|
||||
|
||||
model_config = ModelConfig(pretrained_config=nemotron_config,
|
||||
attn_backend=backend)
|
||||
nemotron = NemotronForCausalLM(model_config).to(dtype).to(device)
|
||||
nemotron.load_weights(hf_nemotron.state_dict())
|
||||
model_config = ModelConfig(pretrained_config=nemotron_config,
|
||||
attn_backend=backend)
|
||||
nemotron = NemotronForCausalLM(model_config)
|
||||
nemotron.load_weights(hf_nemotron.state_dict())
|
||||
|
||||
num_blocks = 1
|
||||
tokens_per_block = 128
|
||||
|
||||
@@ -1,5 +1,6 @@
 import os
 import unittest
+from contextlib import contextmanager
 from difflib import SequenceMatcher
 from pathlib import Path
 
@@ -368,3 +369,11 @@ def similar(a, b, threshold=0.8):
 def get_project_root(test_file: str) -> Path:
     return next(p for p in Path(test_file).resolve().parents
                 if (p / 'tests').is_dir() and (p / "tensorrt_llm").is_dir())
+
+
+@contextmanager
+def default_dtype(dtype: torch.dtype):
+    cur_default = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    yield
+    torch.set_default_dtype(cur_default)
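The new default_dtype helper swaps the process-global default dtype for the duration of the with block and restores the previous value on exit. As written, the restore is skipped if the body raises; a hedged sketch of an exception-safe variant (not the committed helper, the name is illustrative) would wrap the yield in try/finally:

from contextlib import contextmanager

import torch


@contextmanager
def default_dtype_safe(dtype: torch.dtype):
    # Restore the previous default dtype even if the body raises.
    cur_default = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(cur_default)

Usage matches the tests above, combined with the device context manager: with torch.device('cuda'), default_dtype_safe(torch.bfloat16): ...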