[https://nvbugs/5549081][fix] Fix device id assignment for some vision models (#8070)

Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
This commit is contained in:
Chang Liu 2025-10-01 20:28:05 -07:00 committed by GitHub
parent bd3d0ad233
commit 726ac07cc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 3 deletions

View File

@ -726,7 +726,8 @@ class HCXVisionModel:
self.vision_config = self.pretrained_config.vision_config
model_path = self.pretrained_config._name_or_path
self.device = f"cuda:{model_config.mapping.rank}"
# TODO: use config.mapping.get_local_rank() instead
self.device = f"cuda:{torch.cuda.current_device()}"
hf_model_config = AutoConfig.from_pretrained(model_path,
trust_remote_code=True)

View File

@ -998,7 +998,8 @@ class Llama4VisionEncoder(nn.Module):
**kwargs):
super().__init__()
self.pretrained_config = model_config.pretrained_config
self.device = f"cuda:{model_config.mapping.rank}"
# TODO: use config.mapping.get_local_rank() instead
self.device = f"cuda:{torch.cuda.current_device()}"
self.dtype = self.pretrained_config.text_config.torch_dtype

View File

@ -295,7 +295,8 @@ class LlavaNextVisionModel(nn.Module):
super().__init__()
self.model_config = model_config
self.pretrained_config = model_config.pretrained_config
self.device = f"cuda:{model_config.mapping.rank}"
# TODO: use config.mapping.get_local_rank() instead
self.device = f"cuda:{torch.cuda.current_device()}"
model_path = self.pretrained_config._name_or_path
# Determine the actual local path for model files