mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5549081][fix] Fix device id assignment for some vision models (#8070)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
This commit is contained in:
parent bd3d0ad233
commit 726ac07cc0
@@ -726,7 +726,8 @@ class HCXVisionModel:
|
||||
self.vision_config = self.pretrained_config.vision_config
|
||||
|
||||
model_path = self.pretrained_config._name_or_path
|
||||
-self.device = f"cuda:{model_config.mapping.rank}"
|
||||
+# TODO: use config.mapping.get_local_rank() instead
|
||||
+self.device = f"cuda:{torch.cuda.current_device()}"
|
||||
|
||||
hf_model_config = AutoConfig.from_pretrained(model_path,
|
||||
trust_remote_code=True)
|
||||
|
||||
@@ -998,7 +998,8 @@ class Llama4VisionEncoder(nn.Module):
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.pretrained_config = model_config.pretrained_config
|
||||
-self.device = f"cuda:{model_config.mapping.rank}"
|
||||
+# TODO: use config.mapping.get_local_rank() instead
|
||||
+self.device = f"cuda:{torch.cuda.current_device()}"
|
||||
|
||||
self.dtype = self.pretrained_config.text_config.torch_dtype
|
||||
|
||||
|
||||
@@ -295,7 +295,8 @@ class LlavaNextVisionModel(nn.Module):
|
||||
super().__init__()
|
||||
self.model_config = model_config
|
||||
self.pretrained_config = model_config.pretrained_config
|
||||
-self.device = f"cuda:{model_config.mapping.rank}"
|
||||
+# TODO: use config.mapping.get_local_rank() instead
|
||||
+self.device = f"cuda:{torch.cuda.current_device()}"
|
||||
model_path = self.pretrained_config._name_or_path
|
||||
|
||||
# Determine the actual local path for model files
|
||||
|
||||
Loading…
Reference in New Issue
Block a user