Merge 0087e959f2 into 6df2c8a074
Commit d650438b8d
@@ -66,6 +66,9 @@ cache_transceiver_config:
 disable_overlap_scheduler: true
 # Ray executor configuration
 orchestrator_type: "ray"
+max_batch_size: 1
+max_num_tokens: 512
+max_seq_len: 128
 EOF
 else
     cat > extra_llm_config.yaml << EOF
@@ -75,6 +78,10 @@ cache_transceiver_config:
     max_tokens_in_buffer: 2048
 disable_overlap_scheduler: true
 # Using default executor MPI (no orchestrator_type specified)
+# Memory-saving parameters
+max_batch_size: 1
+max_num_tokens: 512
+max_seq_len: 128
 EOF
 fi
 
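For orientation (not part of the diff): the Ray-branch heredoc above amounts to roughly the following settings, sketched here as a Python dict and dumped with PyYAML, which is assumed to be available; the cache_transceiver_config block and any other fields the script writes are omitted.

    # Sketch only: keys taken from the heredoc above.
    import yaml  # assumes PyYAML is installed

    ray_extra_llm_config = {
        "disable_overlap_scheduler": True,
        "orchestrator_type": "ray",  # the MPI branch leaves this unset (default executor)
        "max_batch_size": 1,         # memory-saving limits added by this change
        "max_num_tokens": 512,
        "max_seq_len": 128,
    }

    with open("extra_llm_config.yaml", "w") as f:
        yaml.safe_dump(ray_extra_llm_config, f, sort_keys=False)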
@@ -91,6 +98,9 @@ context_servers:
   num_instances: 1
   tensor_parallel_size: $TP_SIZE
   pipeline_parallel_size: 1
+  max_batch_size: 1
+  max_num_tokens: 512
+  max_seq_len: 128
   kv_cache_config:
     free_gpu_memory_fraction: 0.2
   cache_transceiver_config:
@@ -101,6 +111,9 @@ generation_servers:
   num_instances: 1
   tensor_parallel_size: $TP_SIZE
   pipeline_parallel_size: 1
+  max_batch_size: 1
+  max_num_tokens: 512
+  max_seq_len: 128
   cache_transceiver_config:
     backend: "UCX"
   urls:
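In the same spirit, a hedged sketch of the per-server sections that the two hunks above modify; the structure is inferred from the visible lines, and the urls values (ports 8001/8002, matching the trtllm-serve commands below) plus any omitted fields are assumptions.

    # Sketch only: structure inferred from the context_servers/generation_servers hunks.
    server_section = {
        "num_instances": 1,
        "tensor_parallel_size": 2,   # stands in for $TP_SIZE
        "pipeline_parallel_size": 1,
        "max_batch_size": 1,         # new memory-saving limits
        "max_num_tokens": 512,
        "max_seq_len": 128,
        "cache_transceiver_config": {"backend": "UCX"},
    }

    disagg_config = {
        "context_servers": {
            **server_section,
            "kv_cache_config": {"free_gpu_memory_fraction": 0.2},
            "urls": ["localhost:8001"],  # assumed
        },
        "generation_servers": {**server_section, "urls": ["localhost:8002"]},  # assumed
    }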
@@ -125,18 +138,31 @@ fi
 
 # Launching context servers
 echo "Launching context servers..."
-if [[ "$BACKEND" == "mpi" ]]; then
-    export CUDA_VISIBLE_DEVICES=0
-fi
+CTX_GPUS=$(seq -s, 0 $((TP_SIZE - 1)))
+echo "Context server using GPUs: $CTX_GPUS (via CUDA_VISIBLE_DEVICES)"
+(
+    if [[ "$BACKEND" == "mpi" ]]; then
+        export CUDA_VISIBLE_DEVICES=$CTX_GPUS
+    fi
+
-trtllm-serve $MODEL_DIR --host localhost --tp_size $TP_SIZE --port 8001 --kv_cache_free_gpu_memory_fraction 0.15 --backend pytorch --config extra_llm_config.yaml &> output_ctx0 &
+    trtllm-serve $MODEL_DIR --host localhost --tp_size $TP_SIZE --port 8001 --kv_cache_free_gpu_memory_fraction 0.15 --backend pytorch --extra_llm_api_options extra_llm_config.yaml
+) &> output_ctx0 &
 CTX_PID=$!
 echo "Context server started with PID: $CTX_PID"
 
-if [[ "$BACKEND" == "mpi" ]]; then
-    export CUDA_VISIBLE_DEVICES=1
-fi
 # Launching generation servers
 echo "Launching generation servers..."
-trtllm-serve $MODEL_DIR --host localhost --tp_size $TP_SIZE --port 8002 --kv_cache_free_gpu_memory_fraction 0.15 --backend pytorch --config extra_llm_config.yaml &> output_gen0 &
+GEN_GPUS=$(seq -s, $TP_SIZE $((2 * TP_SIZE - 1)))
+echo "Generation server using GPUs: $GEN_GPUS (via CUDA_VISIBLE_DEVICES)"
+(
+    if [[ "$BACKEND" == "mpi" ]]; then
+        export CUDA_VISIBLE_DEVICES=$GEN_GPUS
+    fi
+
+    trtllm-serve $MODEL_DIR --host localhost --tp_size $TP_SIZE --port 8002 --kv_cache_free_gpu_memory_fraction 0.15 --backend pytorch --extra_llm_api_options extra_llm_config.yaml
+) &> output_gen0 &
 GEN_PID=$!
 echo "Generation server started with PID: $GEN_PID"
 
 # Launching disaggregated server
 echo "Launching disaggregated server..."
 
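Two things worth noting in the launch change (illustration only, not part of the diff): CUDA_VISIBLE_DEVICES is now exported inside the ( ... ) subshell, so the assignment stays scoped to that server's launch, and the seq calls simply build comma-separated GPU index lists, handing the context servers GPUs 0..TP_SIZE-1 and the generation servers GPUs TP_SIZE..2*TP_SIZE-1. A Python equivalent of the index computation:

    # Python equivalent of the seq-based GPU assignment above (illustrative sketch).
    def gpu_lists(tp_size: int) -> tuple[str, str]:
        ctx_gpus = ",".join(str(i) for i in range(tp_size))               # seq -s, 0 $((TP_SIZE - 1))
        gen_gpus = ",".join(str(i) for i in range(tp_size, 2 * tp_size))  # seq -s, $TP_SIZE $((2 * TP_SIZE - 1))
        return ctx_gpus, gen_gpus

    print(gpu_lists(2))  # ('0,1', '2,3')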
@@ -2591,7 +2591,8 @@ class PyExecutor:
                 f'Send first token response for request {req.py_request_id}'
             )
             response = req.create_response(False, self.dist.rank)
-            new_responses.append((req.py_request_id, response))
+            if response:
+                new_responses.append((req.py_request_id, response))
 
         self._enqueue_responses(new_responses)
 
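The added guard suggests create_response can come back empty for a given step, and only real responses should reach _enqueue_responses. A simplified sketch of the pattern with stand-in objects (not the actual PyExecutor internals):

    # Simplified sketch: skip enqueueing when no response object was produced,
    # so downstream consumers never see a (request_id, None) pair.
    def collect_first_token_responses(requests, rank):
        new_responses = []
        for req in requests:
            response = req.create_response(False, rank)  # may be falsy
            if response:
                new_responses.append((req.py_request_id, response))
        return new_responses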
@@ -62,8 +62,6 @@ def test_llm_inference_distributed_ray(ray_example_root, llm_venv, tp_size,
 @pytest.mark.skip_less_device(2)
 @pytest.mark.parametrize("tp_size", [1, 2], ids=["tp1", "tp2"])
 def test_ray_disaggregated_serving(ray_example_root, llm_venv, tp_size):
-    if tp_size == 1:
-        pytest.skip("https://nvbugs/5682551")
 
     if get_device_count() < tp_size * 2:
         pytest.skip(f"Need {tp_size * 2} GPUs.")
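The device check exists because disaggregated serving needs tp_size GPUs for the context server plus tp_size for the generation server. A hedged sketch of a get_device_count() helper of the kind the test relies on; the real helper is defined elsewhere in the test utilities and may differ:

    # Illustrative sketch only; not the helper actually used by test_ray.py.
    def get_device_count() -> int:
        try:
            import torch
            return torch.cuda.device_count()
        except ImportError:
            return 0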
@@ -101,8 +99,8 @@ def test_ray_disaggregated_serving(ray_example_root, llm_venv, tp_size):
         stderr=subprocess.PIPE,
         env=env_copy,
     ):
-        assert wait_for_server("localhost", 8000, timeout_seconds=180), \
-            "Disaggregated server failed to start within 3 minutes"
+        assert wait_for_server("localhost", 8000, timeout_seconds=300), \
+            "Disaggregated server failed to start within 5 minutes"
 
         result = subprocess.run([
             "curl", "-sS", "-w", "\n%{http_code}",
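For context, a hedged sketch of what a wait_for_server(host, port, timeout_seconds) helper of this shape typically does; the real implementation is not part of this diff:

    # Poll until a TCP connection to (host, port) succeeds or the timeout expires.
    import socket
    import time

    def wait_for_server(host: str, port: int, timeout_seconds: int = 300) -> bool:
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            try:
                with socket.create_connection((host, port), timeout=5):
                    return True
            except OSError:
                time.sleep(2)
        return False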
@@ -112,7 +110,7 @@ def test_ray_disaggregated_serving(ray_example_root, llm_venv, tp_size):
         ],
             capture_output=True,
             text=True,
-            timeout=30)
+            timeout=60)
 
         *body_lines, status_line = result.stdout.strip().splitlines()
         body = "\n".join(body_lines)
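The -w "\n%{http_code}" write-out makes curl append the HTTP status code as a final line after the response body, which is exactly what the starred unpacking above peels off. A small self-contained illustration:

    # Split the body from the trailing status line produced by `curl -w "\n%{http_code}"`.
    def split_curl_output(stdout: str) -> tuple[str, str]:
        *body_lines, status_line = stdout.strip().splitlines()
        return "\n".join(body_lines), status_line

    body, status = split_curl_output('{"id": "cmpl-1"}\n200')
    assert (body, status) == ('{"id": "cmpl-1"}', "200")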
@@ -303,7 +303,6 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_d
 unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627)
 examples/serve/test_serve.py::test_config_file_loading[--config] SKIP (https://nvbugs/5754977)
 full:RTXPro6000D/accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugspro.nvidia.com/bug/5794313)
-examples/test_ray.py::test_ray_disaggregated_serving[tp2] SKIP (https://nvbugs/5612502)
 unittest/executor/test_rpc_proxy.py SKIP (https://nvbugs/5605741)
 unittest/executor/test_rpc_worker.py SKIP (https://nvbugs/5605741)
 triton_server/test_triton.py::test_gpt_gather_logits[gpt-gather-logits] SKIP (https://nvbugs/5766960)