[https://nvbugs/5494718][fix] Fix Single GPU Multi-node issue and OOM on DGX Spark (#8514)

Signed-off-by: Simeng Liu <simengl@nvidia.com>
Simeng Liu 2025-10-24 19:09:07 -07:00 committed by GitHub
parent 812bc8c954
commit 2b27810198
GPG Key ID: B5690EEEBB952194
8 changed files with 39 additions and 12 deletions

View File

@@ -14,8 +14,10 @@ from torch.nn.parameter import Parameter
 import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
 from tensorrt_llm._torch.peft.lora.layer import LoraLayer
+from tensorrt_llm._utils import is_device_integrated
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                      AllReduceStrategy)
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
@@ -67,6 +69,15 @@ def load_weight_shard(
         tensor_parallel_mode: Optional[TensorParallelMode] = None,
         device: torch.device = torch.device('cpu'),
 ) -> torch.Tensor:
+    # Skip device transfers on integrated GPUs to conserve shared memory
+    if weight.device.type != device.type and is_device_integrated():
+        # For integrated GPU systems (e.g., DGX Spark), CPU and GPU share limited physical memory.
+        # Avoiding device transfers reduces memory consumption and unnecessary data copies,
+        # enabling support for larger models on memory-constrained systems.
+        logger.warning(
+            f"[load_weight_shard] Skipping device transfer from {weight.device} to {device} on integrated GPU to conserve shared memory."
+        )
+        device = weight.device
     if isinstance(weight, torch.Tensor):
         tensor_shape = weight.shape
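
Note on the hunk above: on integrated-GPU systems such as DGX Spark, the CPU and GPU draw from one shared pool of physical memory, so copying a checkpoint shard to the GPU only duplicates it inside that same pool. The guard therefore keeps the shard on whichever device already holds it. A minimal sketch of the pattern, with a hypothetical `load_shard` helper standing in for `load_weight_shard` and the integrated check passed in as a flag:

```python
import torch


def load_shard(weight: torch.Tensor, device: torch.device,
               integrated_gpu: bool) -> torch.Tensor:
    """Hypothetical stand-in for load_weight_shard's device handling."""
    if weight.device.type != device.type and integrated_gpu:
        # CPU and GPU share one physical memory pool: moving the shard would
        # only duplicate it there, so keep it where it already lives.
        return weight
    return weight.to(device)


shard = torch.empty(4096, 4096)  # host copy of a checkpoint shard
# Discrete GPU: the shard is copied. Integrated GPU: it is returned unchanged.
loaded = load_shard(shard, torch.device("cuda"), integrated_gpu=True)
assert loaded.device == shard.device
```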

View File

@@ -577,7 +577,7 @@ def local_mpi_barrier():
 def mpi_broadcast(obj, root=0):
-    return mpi_comm().bcast(obj, root) if is_multi_device_enable() else obj
+    return mpi_comm().bcast(obj, root) if global_mpi_size() > 1 else obj
 def mpi_allgather(obj):
@@ -1141,17 +1141,6 @@ class KVCacheEventSerializer:
         }
-def is_multi_device_enable():
-    """
-    This method evaluates if we are running on multiple GPUs and the flag ENABLE_MULTI_DEVICE is set.
-    So we can avoid broadcast calls on single GPU.
-    Issue: https://github.com/NVIDIA/TensorRT-LLM/issues/5927
-    ENABLE_MULTI_DEVICE is true by default when building TensorRT LLM so we need to also check
-    the number of devices
-    """
-    return local_mpi_size() > 1
 def set_prometheus_multiproc_dir() -> object:
     # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266
     global prometheus_multiproc_dir
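
The two hunks above are the single-GPU multi-node fix: `is_multi_device_enable()` gated the broadcast on `local_mpi_size()`, which counts only the ranks on the current node, so a multi-node job with one GPU per node saw a local size of 1 on every rank, skipped the `bcast`, and left non-root ranks without the object. Guarding on `global_mpi_size() > 1` broadcasts whenever any other rank exists in the job. A minimal mpi4py sketch of the difference (not the TensorRT-LLM implementation):

```python
# Run with e.g.: mpirun -np 2 python snippet.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
world_size = comm.Get_size()  # total ranks across all nodes

# Ranks sharing this node -- roughly what a "local" MPI size counts.
local_size = comm.Split_type(MPI.COMM_TYPE_SHARED).Get_size()

obj = {"payload": 42} if comm.Get_rank() == 0 else None

# With one rank (GPU) per node, local_size == 1 everywhere, so a guard like
# `if local_size > 1` would skip the bcast and leave obj == None off-root.
# Guarding on the global size broadcasts whenever more than one rank exists.
obj = comm.bcast(obj, root=0) if world_size > 1 else obj
assert obj == {"payload": 42}
```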
@@ -1174,3 +1163,19 @@ def torch_pybind11_abi() -> str:
     if TORCH_PYBIND11_ABI is None:
         TORCH_PYBIND11_ABI = f"{torch._C._PYBIND11_COMPILER_TYPE}{torch._C._PYBIND11_STDLIB}{torch._C._PYBIND11_BUILD_ABI}"
     return TORCH_PYBIND11_ABI
+@lru_cache(maxsize=1)
+def is_device_integrated() -> bool:
+    """Check if the current GPU device is integrated (shares physical memory with CPU).
+    Integrated GPU systems include DGX Spark and other unified memory architectures.
+    This function caches the result to avoid repeated CUDA device property queries.
+    Returns:
+        bool: True if the GPU is integrated, False otherwise. Returns False if CUDA
+        is not available.
+    """
+    if not torch.cuda.is_available():
+        return False
+    return torch.cuda.get_device_properties().is_integrated
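
The helper added above keys off the `is_integrated` field that newer PyTorch builds expose on CUDA device properties. A small usage sketch (hypothetical `on_unified_memory_gpu` name, with a `getattr` fallback for PyTorch versions that predate the field):

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=1)
def on_unified_memory_gpu() -> bool:
    # The property query is cheap but repeated on every weight load, so cache
    # the single answer for the lifetime of the process.
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return bool(getattr(props, "is_integrated", False))


# Example: pick a staging device for checkpoint shards.
if on_unified_memory_gpu():
    staging = torch.device("cpu")  # avoid duplicating data in shared memory
else:
    staging = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```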

View File

@@ -3414,6 +3414,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
     MODEL_PATH = f"{llm_models_root()}/gpt_oss/gpt-oss-120b"
+    @pytest.mark.skip(reason="https://nvbugs/5596343")
     @pytest.mark.parametrize(
         "kv_cache_dtype",
         ["auto", pytest.param("fp8", marks=skip_pre_blackwell)])
@@ -3465,6 +3466,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
         task = GSM8K(model_name)
         task.evaluate(llm, is_integration_test=True)
+    @pytest.mark.skip(reason="https://nvbugs/5596343")
     @pytest.mark.skip_less_device(4)
     @pytest.mark.parametrize(
         "kv_cache_dtype",
@@ -3668,6 +3670,7 @@ class TestQwen2_VL_7B(LlmapiAccuracyTestHarness):
     kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
+    @pytest.mark.skip(reason="https://nvbugs/5601909")
     def test_auto_dtype(self):
         with LLM(self.MODEL_PATH,
                  max_num_tokens=16384,

View File

@@ -263,6 +263,7 @@ def test_model(build_google_tests, model, prepare_model, run_model_tests,
         run_model_tests(model, run_fp8)
+@pytest.mark.skip(reason="https://nvbugs/5601670")
 @pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
                          indirect=True)
 @pytest.mark.parametrize("model", ["bart", "gpt", "t5"])

View File

@@ -509,6 +509,7 @@ def background_workers(llm_venv, config_file: str, num_ranks: int = None):
         proc.wait()
+@pytest.mark.skip(reason="https://nvbugs/5372970")
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
 def test_workers_conditional_disaggregation(disaggregated_test_root,

View File

@@ -1665,6 +1665,7 @@ def test_openai_lora(llm_root, llm_venv):
     llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")])
+@pytest.mark.skip(reason="https://nvbugs/5596377")
 def test_openai_chat_multimodal_example(llm_root, llm_venv):
     test_root = unittest_path() / "llmapi" / "apps"
     llm_venv.run_cmd(

View File

@@ -55,6 +55,7 @@ class TestRpcProxy:
         return proxy
+    @pytest.mark.skip(reason="https://nvbugs/5579234")
     @pytest.mark.parametrize("num_reqs", [1, 10])
     def test_tp1(self, num_reqs):
         tokenizer = TransformersTokenizer.from_pretrained(model_path)

View File

@@ -98,6 +98,7 @@ class TestRpcWorkerTP1:
                 break
         assert 0 < len(results) <= 5
+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     @pytest.mark.asyncio
     @pytest.mark.parametrize("req_count", [10])
     async def test_main_loop_async(self, req_count: int):
@@ -175,6 +176,7 @@ class TestRpcWorkerTP1:
         await process_request_streaming()
+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     @pytest.mark.asyncio
     async def test_fetch_stats_loop_async(self):
         await asyncio.sleep(1)
@@ -227,6 +229,7 @@ class TestRpcWorkerTP2:
     @skip_single_gpu
     @pytest.mark.gpu2
+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     def test_create_shutdown(self):
         # Invoke setup_engine in rank 0, and that will unblock all the ranks to
         # invoke setup_engine simultaneously.
@@ -234,6 +237,7 @@ class TestRpcWorkerTP2:
     @skip_single_gpu
     @pytest.mark.gpu2
+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     def test_fetch_responses_sync(self):
         # Wait a bit to ensure engine is ready
         time.sleep(1)