Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[https://nvbugs/5777041][fix] fix AutoDeploy ep sharding test (#10460)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
commit 15b43e8a14 (parent 94c7b69048)
@@ -322,8 +322,6 @@ accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_laten
 triton_server/test_triton.py::test_llava_onevision[llava_onevision] SKIP (https://nvbugs/5775205)
 triton_server/test_triton.py::test_gpt_ib_lad[gpt-ib-lad] SKIP (https://nvbugs/5775223)
 unittest/_torch/modules/test_fused_moe.py::test_fused_moe_fp8_blockwise_cute_dsl_multi_gpu[MoEWeightLoadingMode.FUSED_GATE_UP_PROJ-DefaultMoeRoutingMethod-1] SKIP (https://nvbugs/5775256)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[3-2] SKIP (https://nvbugs/5777041)
-unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py::test_ep_shard[8-2] SKIP (https://nvbugs/5777041)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8] SKIP (https://nvbugs/5769890)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16] SKIP (https://nvbugs/5769890)
 disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-v3-8b-hf] SKIP (https://nvbugs/5769890,https://nvbugs/5748683)
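Each waive entry pairs a full pytest node ID with a SKIP marker and the nvbugs link(s) tracking the failure; the two test_ep_shard entries for https://nvbugs/5777041 are removed here because this commit fixes the test itself. As a minimal, hypothetical sketch of that entry format (the harness's real reader is not shown in this diff), one line can be parsed like so:

    import re

    # One waive entry: "<pytest node id> SKIP (<comma-separated bug URLs>)".
    _WAIVE_RE = re.compile(r"^(?P<node_id>\S+)\s+SKIP\s+\((?P<bugs>[^)]+)\)$")

    def parse_waive_line(line: str):
        """Return (node_id, [bug_urls]), or None if the line is not a waive entry."""
        match = _WAIVE_RE.match(line.strip())
        if match is None:
            return None
        return match.group("node_id"), match.group("bugs").split(",")

For example, the llama-v3-8b-hf entry above yields two bug URLs, since several trackers can share one waived test.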
@@ -31,10 +31,6 @@ def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None:
     ).to(device=device, dtype=torch.bfloat16)
     x = model.get_input(device=device, dtype=torch.bfloat16)
 
-    if world_size > num_experts:
-        print(f"world_size {world_size} > num_experts {num_experts}, skipping test")
-        return
-
 def _get_expected_num_params(rank: int, world_size: int, num_p_og: int) -> int:
     if world_size <= 1:
         return num_p_og
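The deleted guard lived inside the spawned per-rank job: when world_size exceeded num_experts, every rank printed a message and returned, all processes exited cleanly, and the configuration was reported as passing without exercising any sharding. A minimal sketch of why the guard belongs in the parent test process instead (names here are illustrative, not the real harness):

    import pytest

    NUM_EXPERTS = 3  # illustrative stand-in for the test parameter

    def _worker(rank: int, world_size: int) -> None:
        # Old pattern: guard inside the job. Every rank returns early, each
        # process exits with status 0, and the test is recorded as PASSED
        # even though nothing was verified.
        if world_size > NUM_EXPERTS:
            return
        # ... real sharding assertions would run here ...

    @pytest.mark.parametrize("world_size", [2, 8])
    def test_guard_in_parent(world_size: int) -> None:
        # New pattern: decide before spawning workers, so pytest reports an
        # explicit SKIP instead of a vacuous PASS.
        if world_size > NUM_EXPERTS:
            pytest.skip(f"world_size {world_size} > num_experts {NUM_EXPERTS}")
        _worker(rank=0, world_size=world_size)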
@@ -141,9 +141,11 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) -> None:
     run_sharding_pattern_detection_test(detected_transformations, expected_transformations)
 
 
-@pytest.mark.parametrize("device_count", get_device_counts())
+@pytest.mark.parametrize("device_count", get_device_counts([2, 8]))
 @pytest.mark.parametrize("num_experts", [3, 8])
 def test_ep_shard(device_count: int, num_experts: int):
+    if device_count > num_experts:
+        pytest.skip(f"world_size {device_count} > num_experts {num_experts}")
     dist_common.spawn_multiprocess_job(
         job=partial(_run_ep_shard_job, num_experts),
         size=device_count,
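Passing get_device_counts([2, 8]) pins the device_count parametrization to 2- and 8-GPU world sizes rather than the helper's defaults; combined with num_experts values of 3 and 8, this is what generates IDs such as the previously waived test_ep_shard[3-2]. The helper's implementation is not part of this diff; a plausible sketch of its contract, assuming it filters the candidate counts by the GPUs actually available:

    import torch

    def get_device_counts(candidates=None):
        # Hypothetical sketch, not the AutoDeploy helper itself: keep only
        # the world sizes this machine can serve, so test_ep_shard yields at
        # most a 2-GPU and an 8-GPU run per num_experts value.
        available = torch.cuda.device_count()
        if candidates is None:
            candidates = [1, 2, 4, 8]  # assumed defaults
        return [count for count in candidates if count <= available]

Together with the new pytest.skip guard, infeasible combinations (e.g. device_count 8 with num_experts 3) now surface as explicit skips rather than needing waive-list entries.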