Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5494718][fix] Fix Single GPU Multi-node issue and OOM on DGX Spark (#8514)
Signed-off-by: Simeng Liu <simengl@nvidia.com>
This commit is contained in:
parent 812bc8c954
commit 2b27810198
@@ -14,8 +14,10 @@ from torch.nn.parameter import Parameter
 import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
 from tensorrt_llm._torch.peft.lora.layer import LoraLayer
+from tensorrt_llm._utils import is_device_integrated
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
                                      AllReduceStrategy)
+from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.quantization.functional import \
     preprocess_weights_for_mixed_gemm
@@ -67,6 +69,15 @@ def load_weight_shard(
         tensor_parallel_mode: Optional[TensorParallelMode] = None,
         device: torch.device = torch.device('cpu'),
 ) -> torch.Tensor:
+    # Skip device transfers on integrated GPUs to conserve shared memory
+    if weight.device.type != device.type and is_device_integrated():
+        # For integrated GPU systems (e.g., DGX Spark), CPU and GPU share limited physical memory.
+        # Avoiding device transfers reduces memory consumption and unnecessary data copies,
+        # enabling support for larger models on memory-constrained systems.
+        logger.warning(
+            f"[load_weight_shard] Skipping device transfer from {weight.device} to {device} on integrated GPU to conserve shared memory."
+        )
+        device = weight.device
     if isinstance(weight, torch.Tensor):
         tensor_shape = weight.shape
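Note (not part of the diff): a minimal standalone sketch of the guard introduced above. The maybe_move_to_device helper and the shard size are illustrative only, and the sketch assumes PyTorch exposes is_integrated on the CUDA device properties, which is what the patch itself relies on.

# Illustrative sketch of the integrated-GPU guard; not TensorRT-LLM code.
import torch


def maybe_move_to_device(weight: torch.Tensor,
                         device: torch.device) -> torch.Tensor:
    """Copy `weight` to `device` unless the GPU shares physical memory with the CPU."""
    integrated = (torch.cuda.is_available()
                  and torch.cuda.get_device_properties(0).is_integrated)
    if weight.device.type != device.type and integrated:
        # On unified-memory systems (e.g., DGX Spark) a host->device copy would
        # roughly double the footprint of the shard, so keep the original tensor.
        return weight
    return weight.to(device)


if __name__ == "__main__" and torch.cuda.is_available():
    shard = torch.empty(1024, 1024)  # CPU tensor standing in for a checkpoint shard
    out = maybe_move_to_device(shard, torch.device("cuda"))
    print(out.device)  # stays on the CPU when the GPU is integrated, else cuda:0

On a unified-memory machine the shard is returned untouched, so no second copy of the weights is created.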
@@ -577,7 +577,7 @@ def local_mpi_barrier():


 def mpi_broadcast(obj, root=0):
-    return mpi_comm().bcast(obj, root) if is_multi_device_enable() else obj
+    return mpi_comm().bcast(obj, root) if global_mpi_size() > 1 else obj


 def mpi_allgather(obj):
@@ -1141,17 +1141,6 @@ class KVCacheEventSerializer:
         }
-
-
-def is_multi_device_enable():
-    """
-    This method evaluates if we are running on multiple GPUs and the flag ENABLE_MULTI_DEVICE is set.
-    So we can avoid broadcast calls on single GPU.
-    Issue: https://github.com/NVIDIA/TensorRT-LLM/issues/5927
-    ENABLE_MULTI_DEVICE is true by default when building TensorRT LLM so we need to also check
-    the number of devices
-    """
-    return local_mpi_size() > 1


 def set_prometheus_multiproc_dir() -> object:
     # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.10/python/sglang/srt/utils.py#L1266
     global prometheus_multiproc_dir
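Note (not part of the diff): local_mpi_size() counts only the ranks on one node, so a single-GPU multi-node job (for example two nodes with one GPU each) reported a local size of 1 and the old guard skipped the broadcast even though two ranks exist globally, leaving non-root ranks without the object. A standalone mpi4py sketch of the difference; the object contents and launch command are illustrative, not TensorRT-LLM code.

# Illustrative only: contrast the old local-size guard with the new global-size guard.
# Launch e.g. with: mpirun -np 2 --map-by ppr:1:node python bcast_guard.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
global_size = comm.Get_size()                                   # ranks across all nodes
local_size = comm.Split_type(MPI.COMM_TYPE_SHARED).Get_size()   # ranks on this node

obj = {"setting": 42} if comm.Get_rank() == 0 else None

# Old guard (removed above): with one rank per node, local_size == 1, the bcast
# is skipped and non-root ranks keep None -> the single-GPU multi-node bug.
obj_old = comm.bcast(obj, root=0) if local_size > 1 else obj

# New guard: global_size == 2, so every rank receives the object.
obj_new = comm.bcast(obj, root=0) if global_size > 1 else obj

print(comm.Get_rank(), obj_old, obj_new)  # rank 1 prints: 1 None {'setting': 42}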
@@ -1174,3 +1163,19 @@ def torch_pybind11_abi() -> str:
     if TORCH_PYBIND11_ABI is None:
         TORCH_PYBIND11_ABI = f"{torch._C._PYBIND11_COMPILER_TYPE}{torch._C._PYBIND11_STDLIB}{torch._C._PYBIND11_BUILD_ABI}"
     return TORCH_PYBIND11_ABI
+
+
+@lru_cache(maxsize=1)
+def is_device_integrated() -> bool:
+    """Check if the current GPU device is integrated (shares physical memory with CPU).
+
+    Integrated GPU systems include DGX Spark and other unified memory architectures.
+    This function caches the result to avoid repeated CUDA device property queries.
+
+    Returns:
+        bool: True if the GPU is integrated, False otherwise. Returns False if CUDA
+        is not available.
+    """
+    if not torch.cuda.is_available():
+        return False
+    return torch.cuda.get_device_properties().is_integrated
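Note (not part of the diff): the @lru_cache(maxsize=1) decorator means the CUDA device-property query runs at most once per process; later calls return the memoized result. A small illustrative sketch of that behavior, where the counter and fake probe are made up rather than TensorRT-LLM code.

# Sketch of the caching behavior provided by @lru_cache(maxsize=1); illustrative only.
from functools import lru_cache

probe_calls = 0


@lru_cache(maxsize=1)
def is_device_integrated_cached() -> bool:
    global probe_calls
    probe_calls += 1   # stands in for the CUDA device-property query
    return False       # pretend we are on a discrete GPU


for _ in range(1000):
    is_device_integrated_cached()

print(probe_calls)                               # 1 -- the probe ran only once
print(is_device_integrated_cached.cache_info())  # hits=999, misses=1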
@@ -3414,6 +3414,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):

     MODEL_PATH = f"{llm_models_root()}/gpt_oss/gpt-oss-120b"

+    @pytest.mark.skip(reason="https://nvbugs/5596343")
     @pytest.mark.parametrize(
         "kv_cache_dtype",
         ["auto", pytest.param("fp8", marks=skip_pre_blackwell)])
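Note (not part of the diff): the remaining test hunks apply the same waiver pattern, adding a skip marker that points at the tracking bug. A minimal, self-contained sketch of how such a marker behaves under pytest; the test names and bug placeholder are illustrative.

# Illustrative waiver pattern; not from the TensorRT-LLM test suites.
import pytest


@pytest.mark.skip(reason="https://nvbugs/<tracking-bug>")
def test_waived_feature():
    assert False  # never executed; pytest reports the test as skipped with the reason


def test_still_running():
    assert True

Running pytest -rs lists the waived test together with the given reason.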
@@ -3465,6 +3466,7 @@ class TestGPTOSS(LlmapiAccuracyTestHarness):
             task = GSM8K(model_name)
             task.evaluate(llm, is_integration_test=True)

+    @pytest.mark.skip(reason="https://nvbugs/5596343")
     @pytest.mark.skip_less_device(4)
     @pytest.mark.parametrize(
         "kv_cache_dtype",
@@ -3668,6 +3670,7 @@ class TestQwen2_VL_7B(LlmapiAccuracyTestHarness):

     kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)

+    @pytest.mark.skip(reason="https://nvbugs/5601909")
     def test_auto_dtype(self):
         with LLM(self.MODEL_PATH,
                  max_num_tokens=16384,
@@ -263,6 +263,7 @@ def test_model(build_google_tests, model, prepare_model, run_model_tests,
         run_model_tests(model, run_fp8)


+@pytest.mark.skip(reason="https://nvbugs/5601670")
 @pytest.mark.parametrize("build_google_tests", ["80", "86", "89", "90"],
                          indirect=True)
 @pytest.mark.parametrize("model", ["bart", "gpt", "t5"])
@@ -509,6 +509,7 @@ def background_workers(llm_venv, config_file: str, num_ranks: int = None):
             proc.wait()


+@pytest.mark.skip(reason="https://nvbugs/5372970")
 @pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
                          indirect=True)
 def test_workers_conditional_disaggregation(disaggregated_test_root,
@@ -1665,6 +1665,7 @@ def test_openai_lora(llm_root, llm_venv):
     llm_venv.run_cmd(["-m", "pytest", str(test_root / "_test_openai_lora.py")])


+@pytest.mark.skip(reason="https://nvbugs/5596377")
 def test_openai_chat_multimodal_example(llm_root, llm_venv):
     test_root = unittest_path() / "llmapi" / "apps"
     llm_venv.run_cmd(
@@ -55,6 +55,7 @@ class TestRpcProxy:

         return proxy

+    @pytest.mark.skip(reason="https://nvbugs/5579234")
     @pytest.mark.parametrize("num_reqs", [1, 10])
     def test_tp1(self, num_reqs):
         tokenizer = TransformersTokenizer.from_pretrained(model_path)
@@ -98,6 +98,7 @@ class TestRpcWorkerTP1:
                 break
         assert 0 < len(results) <= 5

+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     @pytest.mark.asyncio
     @pytest.mark.parametrize("req_count", [10])
     async def test_main_loop_async(self, req_count: int):
@@ -175,6 +176,7 @@ class TestRpcWorkerTP1:

         await process_request_streaming()

+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     @pytest.mark.asyncio
     async def test_fetch_stats_loop_async(self):
         await asyncio.sleep(1)
@@ -227,6 +229,7 @@ class TestRpcWorkerTP2:

     @skip_single_gpu
     @pytest.mark.gpu2
+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     def test_create_shutdown(self):
         # Invoke setup_engine in rank 0, and that will unblock all the ranks to
         # invoke setup_engine simultaneously.
@@ -234,6 +237,7 @@ class TestRpcWorkerTP2:

     @skip_single_gpu
     @pytest.mark.gpu2
+    @pytest.mark.skip(reason="https://nvbugs/5583261")
     def test_fetch_responses_sync(self):
         # Wait a bit to ensure engine is ready
         time.sleep(1)