[fix] speedup modeling unittests (#5579)

Signed-off-by: Omer Ullman Argov <118735753+omera-nv@users.noreply.github.com>
Omer Ullman Argov 2025-06-30 06:30:45 +03:00 committed by GitHub
parent 4fef14da56
commit 1db63c2546
5 changed files with 59 additions and 45 deletions
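
The changes below speed the tests up by constructing each model directly on the GPU and with the target default dtype, instead of building it on the CPU in float32 and then converting it with .to(dtype).to(device). A minimal sketch of the PyTorch mechanism this relies on; the layer and dtype here are illustrative placeholders, not taken from this commit:

import torch
import torch.nn as nn

# Requires PyTorch 2.x, where torch.device works as a context manager.
# Modules created inside the context allocate their parameters directly on
# that device, and the default dtype controls the dtype they are created
# with, so no separate float32 -> bfloat16 copy is needed afterwards.
previous = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
with torch.device('cuda'):
    layer = nn.Linear(4096, 4096)  # stands in for the real model under test
torch.set_default_dtype(previous)

assert layer.weight.device.type == 'cuda'
assert layer.weight.dtype == torch.bfloat16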

View File

@@ -7,7 +7,7 @@ import torch
from parameterized import parameterized
from transformers import LlamaConfig
from transformers import LlamaForCausalLM as HFLlamaForCausalLM
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -97,9 +97,10 @@ class TestLlama(unittest.TestCase):
dtype = llama_config.torch_dtype
device = torch.device('cuda')
-model_config = ModelConfig(pretrained_config=llama_config,
-                           quant_config=quant_config)
-llama = LlamaForCausalLM(model_config).to(device)
+with torch.device(device), default_dtype(dtype):
+    model_config = ModelConfig(pretrained_config=llama_config,
+                               quant_config=quant_config)
+    llama = LlamaForCausalLM(model_config).to(device)
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
dtype=torch.int,
@@ -217,12 +218,14 @@ class TestLlama(unittest.TestCase):
dtype = llama_config.torch_dtype
device = torch.device('cuda')
-hf_llama = HFLlamaForCausalLM(llama_config).to(dtype).to(device).eval()
+with torch.device(device), default_dtype(dtype):
+    hf_llama = HFLlamaForCausalLM(llama_config).eval()
-model_config = ModelConfig(pretrained_config=llama_config,
-                           attn_backend=backend)
-llama = LlamaForCausalLM(model_config).to(dtype).to(device)
-llama.load_weights(hf_llama.state_dict())
+    model_config = ModelConfig(pretrained_config=llama_config,
+                               attn_backend=backend)
+    llama = LlamaForCausalLM(model_config).to(dtype).to(device)
+    llama.load_weights(hf_llama.state_dict())
num_blocks = 1
tokens_per_block = 128

View File

@@ -8,7 +8,7 @@ from transformers import Llama4Config
from transformers import \
Llama4ForConditionalGeneration as HFLlama4ForConditionalGeneration
from transformers.cache_utils import DynamicCache
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -152,11 +152,12 @@ class TestLlama4MinLatency(unittest.TestCase):
dtype = llama_config.torch_dtype
device = torch.device('cuda')
-model_config = ModelConfig(pretrained_config=llama_config,
-                           quant_config=quant_config)
-model_config.pytorch_backend_config = PyTorchConfig(
-    enable_min_latency=enable_min_latency)
-llama = Llama4ForConditionalGeneration(model_config).to(device)
+with torch.device(device), default_dtype(dtype):
+    model_config = ModelConfig(pretrained_config=llama_config,
+                               quant_config=quant_config)
+    model_config.pytorch_backend_config = PyTorchConfig(
+        enable_min_latency=enable_min_latency)
+    llama = Llama4ForConditionalGeneration(model_config)
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
dtype=torch.int,
@@ -275,16 +276,15 @@ class TestLlama4MinLatency(unittest.TestCase):
dtype = llama_config.torch_dtype
device = torch.device('cuda')
-hf_llama = HFLlama4ForConditionalGeneration(llama_config).to(dtype).to(
-    device).eval()
+with torch.device(device), default_dtype(dtype):
+    hf_llama = HFLlama4ForConditionalGeneration(llama_config).eval()
-model_config = ModelConfig(pretrained_config=llama_config,
-                           attn_backend=attention_backend)
-model_config.pytorch_backend_config = PyTorchConfig(
-    enable_min_latency=enable_min_latency)
-llama = Llama4ForConditionalGeneration(model_config).to(dtype).to(
-    device)
-llama.load_weights(hf_llama.state_dict())
+    model_config = ModelConfig(pretrained_config=llama_config,
+                               attn_backend=attention_backend)
+    model_config.pytorch_backend_config = PyTorchConfig(
+        enable_min_latency=enable_min_latency)
+    llama = Llama4ForConditionalGeneration(model_config)
+    llama.load_weights(hf_llama.state_dict())
num_blocks = 1
tokens_per_block = 128

View File

@@ -6,7 +6,7 @@ import torch
from parameterized import parameterized
from transformers import MixtralConfig
from transformers import MixtralForCausalLM as HFMixtralForCausalLM
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -82,9 +82,10 @@ class TestMixtral(unittest.TestCase):
dtype = mixtral_config.torch_dtype
device = torch.device("cuda")
-model_config = ModelConfig(pretrained_config=mixtral_config,
-                           quant_config=quant_config)
-mixtral = MixtralForCausalLM(model_config).to(device)
+with torch.device(device), default_dtype(dtype):
+    model_config = ModelConfig(pretrained_config=mixtral_config,
+                               quant_config=quant_config)
+    mixtral = MixtralForCausalLM(model_config)
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
dtype=torch.int32,
@@ -199,13 +200,13 @@ class TestMixtral(unittest.TestCase):
dtype = mixtral_config.torch_dtype
device = torch.device("cuda")
-hf_mixtral = HFMixtralForCausalLM(mixtral_config).to(dtype).to(
-    device).eval()
+with torch.device(device), default_dtype(dtype):
+    hf_mixtral = HFMixtralForCausalLM(mixtral_config).eval()
-model_config = ModelConfig(pretrained_config=mixtral_config,
-                           attn_backend=backend)
-mixtral = MixtralForCausalLM(model_config).to(device)
-mixtral.load_weights(hf_mixtral.state_dict())
+    model_config = ModelConfig(pretrained_config=mixtral_config,
+                               attn_backend=backend)
+    mixtral = MixtralForCausalLM(model_config)
+    mixtral.load_weights(hf_mixtral.state_dict())
num_blocks = 1
tokens_per_block = 128

View File

@@ -7,7 +7,7 @@ import torch
from parameterized import parameterized
from transformers import NemotronConfig
from transformers import NemotronForCausalLM as HFNemotronForCausalLM
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -88,10 +88,11 @@ class TestNemotron(unittest.TestCase):
dtype = nemotron_config.torch_dtype
device = torch.device('cuda')
-model_config = ModelConfig(pretrained_config=nemotron_config,
-                           quant_config=quant_config,
-                           attn_backend="TRTLLM")
-nemotron = NemotronForCausalLM(model_config).to(device)
+with torch.device(device), default_dtype(dtype):
+    model_config = ModelConfig(pretrained_config=nemotron_config,
+                               quant_config=quant_config,
+                               attn_backend="TRTLLM")
+    nemotron = NemotronForCausalLM(model_config)
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
dtype=torch.int,
@@ -210,13 +211,13 @@ class TestNemotron(unittest.TestCase):
dtype = nemotron_config.torch_dtype
device = torch.device('cuda')
-hf_nemotron = HFNemotronForCausalLM(nemotron_config).to(dtype).to(
-    device).eval()
+with torch.device(device), default_dtype(dtype):
+    hf_nemotron = HFNemotronForCausalLM(nemotron_config).eval()
-model_config = ModelConfig(pretrained_config=nemotron_config,
-                           attn_backend=backend)
-nemotron = NemotronForCausalLM(model_config).to(dtype).to(device)
-nemotron.load_weights(hf_nemotron.state_dict())
+    model_config = ModelConfig(pretrained_config=nemotron_config,
+                               attn_backend=backend)
+    nemotron = NemotronForCausalLM(model_config)
+    nemotron.load_weights(hf_nemotron.state_dict())
num_blocks = 1
tokens_per_block = 128

View File

@@ -1,5 +1,6 @@
import os
import unittest
+from contextlib import contextmanager
from difflib import SequenceMatcher
from pathlib import Path
@@ -368,3 +369,11 @@ def similar(a, b, threshold=0.8):
def get_project_root(test_file: str) -> Path:
return next(p for p in Path(test_file).resolve().parents
if (p / 'tests').is_dir() and (p / "tensorrt_llm").is_dir())
+@contextmanager
+def default_dtype(dtype: torch.dtype):
+    cur_default = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    yield
+    torch.set_default_dtype(cur_default)
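
A usage sketch of the new helper, combined with the torch.device context the way the tests above use it; the linear layer and the chosen dtype are hypothetical stand-ins for the models and configs under test:

import torch
import torch.nn as nn

from utils.util import default_dtype  # the helper added above

dtype = torch.bfloat16            # the tests take this from the model config
device = torch.device('cuda')

with torch.device(device), default_dtype(dtype):
    # Parameters are created on `device` in `dtype` at construction time,
    # so no follow-up .to(dtype).to(device) pass over the weights is needed.
    model = nn.Linear(1024, 1024)  # stand-in for the model under test

assert next(model.parameters()).dtype == dtype
assert next(model.parameters()).device.type == 'cuda'
assert torch.get_default_dtype() == torch.float32  # previous default restored

Note that, as written, default_dtype restores the previous default only when the body exits normally; wrapping the yield in a try/finally would also restore it when the wrapped code raises.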