diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py
index c3fed8be6f..e72957997f 100644
--- a/tests/unittest/_torch/modeling/test_modeling_llama.py
+++ b/tests/unittest/_torch/modeling/test_modeling_llama.py
@@ -7,7 +7,7 @@ import torch
 from parameterized import parameterized
 from transformers import LlamaConfig
 from transformers import LlamaForCausalLM as HFLlamaForCausalLM
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
 
 import tensorrt_llm
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -97,9 +97,10 @@ class TestLlama(unittest.TestCase):
         dtype = llama_config.torch_dtype
         device = torch.device('cuda')
 
-        model_config = ModelConfig(pretrained_config=llama_config,
-                                   quant_config=quant_config)
-        llama = LlamaForCausalLM(model_config).to(device)
+        with torch.device(device), default_dtype(dtype):
+            model_config = ModelConfig(pretrained_config=llama_config,
+                                       quant_config=quant_config)
+            llama = LlamaForCausalLM(model_config)
 
         input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                  dtype=torch.int,
@@ -217,12 +218,14 @@ class TestLlama(unittest.TestCase):
         dtype = llama_config.torch_dtype
         device = torch.device('cuda')
 
-        hf_llama = HFLlamaForCausalLM(llama_config).to(dtype).to(device).eval()
+        with torch.device(device), default_dtype(dtype):
+            hf_llama = HFLlamaForCausalLM(llama_config).eval()
 
-        model_config = ModelConfig(pretrained_config=llama_config,
-                                   attn_backend=backend)
-        llama = LlamaForCausalLM(model_config).to(dtype).to(device)
-        llama.load_weights(hf_llama.state_dict())
+            model_config = ModelConfig(pretrained_config=llama_config,
+                                       attn_backend=backend)
+
+            llama = LlamaForCausalLM(model_config)
+            llama.load_weights(hf_llama.state_dict())
 
         num_blocks = 1
         tokens_per_block = 128
diff --git a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py
index 4711155d1b..54319371a8 100644
--- a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py
+++ b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py
@@ -8,7 +8,7 @@ from transformers import Llama4Config
 from transformers import \
     Llama4ForConditionalGeneration as HFLlama4ForConditionalGeneration
 from transformers.cache_utils import DynamicCache
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
 
 import tensorrt_llm
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -152,11 +152,12 @@ class TestLlama4MinLatency(unittest.TestCase):
         dtype = llama_config.torch_dtype
         device = torch.device('cuda')
 
-        model_config = ModelConfig(pretrained_config=llama_config,
-                                   quant_config=quant_config)
-        model_config.pytorch_backend_config = PyTorchConfig(
-            enable_min_latency=enable_min_latency)
-        llama = Llama4ForConditionalGeneration(model_config).to(device)
+        with torch.device(device), default_dtype(dtype):
+            model_config = ModelConfig(pretrained_config=llama_config,
+                                       quant_config=quant_config)
+            model_config.pytorch_backend_config = PyTorchConfig(
+                enable_min_latency=enable_min_latency)
+            llama = Llama4ForConditionalGeneration(model_config)
 
         input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                  dtype=torch.int,
@@ -275,16 +276,15 @@ class TestLlama4MinLatency(unittest.TestCase):
         dtype = llama_config.torch_dtype
         device = torch.device('cuda')
 
-        hf_llama = HFLlama4ForConditionalGeneration(llama_config).to(dtype).to(
-            device).eval()
+        with torch.device(device), default_dtype(dtype):
+            hf_llama = HFLlama4ForConditionalGeneration(llama_config).eval()
 
-        model_config = ModelConfig(pretrained_config=llama_config,
-                                   attn_backend=attention_backend)
-        model_config.pytorch_backend_config = PyTorchConfig(
-            enable_min_latency=enable_min_latency)
-        llama = Llama4ForConditionalGeneration(model_config).to(dtype).to(
-            device)
-        llama.load_weights(hf_llama.state_dict())
+            model_config = ModelConfig(pretrained_config=llama_config,
+                                       attn_backend=attention_backend)
+            model_config.pytorch_backend_config = PyTorchConfig(
+                enable_min_latency=enable_min_latency)
+            llama = Llama4ForConditionalGeneration(model_config)
+            llama.load_weights(hf_llama.state_dict())
 
         num_blocks = 1
         tokens_per_block = 128
diff --git a/tests/unittest/_torch/modeling/test_modeling_mixtral.py b/tests/unittest/_torch/modeling/test_modeling_mixtral.py
index c2b07645cb..edbcf1efd2 100644
--- a/tests/unittest/_torch/modeling/test_modeling_mixtral.py
+++ b/tests/unittest/_torch/modeling/test_modeling_mixtral.py
@@ -6,7 +6,7 @@ import torch
 from parameterized import parameterized
 from transformers import MixtralConfig
 from transformers import MixtralForCausalLM as HFMixtralForCausalLM
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
 
 import tensorrt_llm
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -82,9 +82,10 @@ class TestMixtral(unittest.TestCase):
         dtype = mixtral_config.torch_dtype
         device = torch.device("cuda")
 
-        model_config = ModelConfig(pretrained_config=mixtral_config,
-                                   quant_config=quant_config)
-        mixtral = MixtralForCausalLM(model_config).to(device)
+        with torch.device(device), default_dtype(dtype):
+            model_config = ModelConfig(pretrained_config=mixtral_config,
+                                       quant_config=quant_config)
+            mixtral = MixtralForCausalLM(model_config)
 
         input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                  dtype=torch.int32,
@@ -199,13 +200,13 @@ class TestMixtral(unittest.TestCase):
         dtype = mixtral_config.torch_dtype
         device = torch.device("cuda")
 
-        hf_mixtral = HFMixtralForCausalLM(mixtral_config).to(dtype).to(
-            device).eval()
+        with torch.device(device), default_dtype(dtype):
+            hf_mixtral = HFMixtralForCausalLM(mixtral_config).eval()
 
-        model_config = ModelConfig(pretrained_config=mixtral_config,
-                                   attn_backend=backend)
-        mixtral = MixtralForCausalLM(model_config).to(device)
-        mixtral.load_weights(hf_mixtral.state_dict())
+            model_config = ModelConfig(pretrained_config=mixtral_config,
+                                       attn_backend=backend)
+            mixtral = MixtralForCausalLM(model_config)
+            mixtral.load_weights(hf_mixtral.state_dict())
 
         num_blocks = 1
         tokens_per_block = 128
diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron.py b/tests/unittest/_torch/modeling/test_modeling_nemotron.py
index 81bc6509e7..a17b050ec0 100644
--- a/tests/unittest/_torch/modeling/test_modeling_nemotron.py
+++ b/tests/unittest/_torch/modeling/test_modeling_nemotron.py
@@ -7,7 +7,7 @@ import torch
 from parameterized import parameterized
 from transformers import NemotronConfig
 from transformers import NemotronForCausalLM as HFNemotronForCausalLM
-from utils.util import getSMVersion
+from utils.util import default_dtype, getSMVersion
 
 import tensorrt_llm
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -88,10 +88,11 @@ class TestNemotron(unittest.TestCase):
         dtype = nemotron_config.torch_dtype
         device = torch.device('cuda')
 
-        model_config = ModelConfig(pretrained_config=nemotron_config,
-                                   quant_config=quant_config,
-                                   attn_backend="TRTLLM")
-        nemotron = NemotronForCausalLM(model_config).to(device)
+        with torch.device(device), default_dtype(dtype):
+            model_config = ModelConfig(pretrained_config=nemotron_config,
+                                       quant_config=quant_config,
+                                       attn_backend="TRTLLM")
+            nemotron = NemotronForCausalLM(model_config)
 
         input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                  dtype=torch.int,
@@ -210,13 +211,13 @@ class TestNemotron(unittest.TestCase):
         dtype = nemotron_config.torch_dtype
         device = torch.device('cuda')
 
-        hf_nemotron = HFNemotronForCausalLM(nemotron_config).to(dtype).to(
-            device).eval()
+        with torch.device(device), default_dtype(dtype):
+            hf_nemotron = HFNemotronForCausalLM(nemotron_config).eval()
 
-        model_config = ModelConfig(pretrained_config=nemotron_config,
-                                   attn_backend=backend)
-        nemotron = NemotronForCausalLM(model_config).to(dtype).to(device)
-        nemotron.load_weights(hf_nemotron.state_dict())
+            model_config = ModelConfig(pretrained_config=nemotron_config,
+                                       attn_backend=backend)
+            nemotron = NemotronForCausalLM(model_config)
+            nemotron.load_weights(hf_nemotron.state_dict())
 
         num_blocks = 1
         tokens_per_block = 128
diff --git a/tests/unittest/utils/util.py b/tests/unittest/utils/util.py
index 31dcb8efd9..e34370e9eb 100644
--- a/tests/unittest/utils/util.py
+++ b/tests/unittest/utils/util.py
@@ -1,5 +1,6 @@
 import os
 import unittest
+from contextlib import contextmanager
 from difflib import SequenceMatcher
 from pathlib import Path
 
@@ -368,3 +369,13 @@ def similar(a, b, threshold=0.8):
 def get_project_root(test_file: str) -> Path:
     return next(p for p in Path(test_file).resolve().parents
                 if (p / 'tests').is_dir() and (p / "tensorrt_llm").is_dir())
+
+
+@contextmanager
+def default_dtype(dtype: torch.dtype):
+    cur_default = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        torch.set_default_dtype(cur_default)
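Note on the pattern above: `torch.device(...)` works as a context manager since PyTorch 2.0, so tensors and module parameters constructed inside the `with` block are allocated on that device directly, and `default_dtype` pins the floating-point dtype they are constructed with. Together they replace the `.to(dtype).to(device)` round trips, which first materialized every weight in fp32 on the host and then cast and copied it. A minimal self-contained sketch of the idiom; the `nn.Linear` model is a stand-in for the test models in the diff, and an available CUDA device is assumed:

```python
import torch
from contextlib import contextmanager


@contextmanager
def default_dtype(dtype: torch.dtype):
    """Temporarily override the dtype new floating-point tensors default to."""
    prev = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore even if the body raises, so one test cannot leak a
        # changed default dtype into the rest of the suite.
        torch.set_default_dtype(prev)


# Parameters are created on the GPU in bfloat16 directly; there is no
# transient fp32 copy on the host and no .to(dtype).to(device) round trip.
with torch.device("cuda"), default_dtype(torch.bfloat16):
    model = torch.nn.Linear(16, 16)

assert model.weight.dtype == torch.bfloat16
assert model.weight.is_cuda
```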