From 714f82b485fd63e5945abe8148563045a5522c87 Mon Sep 17 00:00:00 2001
From: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Date: Fri, 18 Jul 2025 10:16:36 +0800
Subject: [PATCH] fix: Unable to load phi4-model with tp_size>1 (#6093)

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
---
 tensorrt_llm/models/phi3/convert.py       |  5 ++--
 tensorrt_llm/models/phi3/split_weights.py | 29 ++++++++++++-------
 .../defs/accuracy/references/mmlu.yaml    |  4 +++
 .../defs/accuracy/test_cli_flow.py        |  7 +++++
 .../test_lists/qa/examples_test_list.txt  |  1 +
 5 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/tensorrt_llm/models/phi3/convert.py b/tensorrt_llm/models/phi3/convert.py
index ddd654dadd..1f11408b69 100644
--- a/tensorrt_llm/models/phi3/convert.py
+++ b/tensorrt_llm/models/phi3/convert.py
@@ -126,8 +126,9 @@ def load_weights_from_hf_model(hf_model, config):
         if "qkv." in key:
             weights[key] = shuffle_qkv_weights(weights[key], config)
 
-    if config.architecture in ['Phi3SmallForCausalLM', "PhiMoEForCausalLM"
-                               ] and config.mapping.has_tp():
+    if config.architecture in [
+            'Phi3SmallForCausalLM', "PhiMoEForCausalLM", "Phi3ForCausalLM"
+    ] and config.mapping.has_tp():
         weights = split_weights_tp(config, weights, torch_dtype)
 
     return weights
diff --git a/tensorrt_llm/models/phi3/split_weights.py b/tensorrt_llm/models/phi3/split_weights.py
index 62a8891230..bca33ba551 100644
--- a/tensorrt_llm/models/phi3/split_weights.py
+++ b/tensorrt_llm/models/phi3/split_weights.py
@@ -145,19 +145,23 @@ def split_weights_tp(config, weights, dtype):
                 -1, hidden_size)
             split_weight = torch.cat(
                 [split(x, tp_size, rank) for x in [q, k, v]], dim=0)
-
-            qkv_bias = qkv_bias.reshape(num_q_per_kv + 2, -1)
-            q = qkv_bias[:num_q_per_kv, :].reshape(-1)
-            k = qkv_bias[num_q_per_kv:num_q_per_kv + 1, :].reshape(-1)
-            v = qkv_bias[num_q_per_kv + 1:num_q_per_kv + 2, :].reshape(-1)
-            split_bias = torch.cat([split(x, tp_size, rank) for x in [q, k, v]],
-                                   dim=0)
+            if qkv_bias is not None:
+                qkv_bias = qkv_bias.reshape(num_q_per_kv + 2, -1)
+                q = qkv_bias[:num_q_per_kv, :].reshape(-1)
+                k = qkv_bias[num_q_per_kv:num_q_per_kv + 1, :].reshape(-1)
+                v = qkv_bias[num_q_per_kv + 1:num_q_per_kv + 2, :].reshape(-1)
+                split_bias = torch.cat(
+                    [split(x, tp_size, rank) for x in [q, k, v]], dim=0)
+            else:
+                split_bias = None
         else:
             split_weight = split_qkv_tp(qkv_weight, num_heads, hidden_size,
                                         tp_size, rank)
-            split_bias = split_qkv_bias_tp(qkv_bias, num_heads, hidden_size,
-                                           tp_size, rank)
-
+            if qkv_bias is not None:
+                split_bias = split_qkv_bias_tp(qkv_bias, num_heads, hidden_size,
+                                               tp_size, rank)
+            else:
+                split_bias = None
         weights.update(get_quant_weight(split_weight, prefix, split_bias))
 
         prefix = layer_prefix + 'attention.dense'
@@ -171,7 +175,10 @@ def split_weights_tp(config, weights, dtype):
             mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(
                 weights, prefix, dtype)
             split_v = split_matrix_tp(mlp_fc_weight, tp_size, rank, dim=0)
-            bias = split_matrix_tp(mlp_fc_bias, tp_size, rank, dim=0)
+            if mlp_fc_bias is not None:
+                bias = split_matrix_tp(mlp_fc_bias, tp_size, rank, dim=0)
+            else:
+                bias = None
             weights.update(get_quant_weight(split_v, prefix, bias))
         else:
             mlp_fc_weight = get_weight(weights, prefix, dtype)
diff --git a/tests/integration/defs/accuracy/references/mmlu.yaml b/tests/integration/defs/accuracy/references/mmlu.yaml
index b6feeee376..c9583dfc82 100644
--- a/tests/integration/defs/accuracy/references/mmlu.yaml
+++ b/tests/integration/defs/accuracy/references/mmlu.yaml
@@ -175,6 +175,10 @@ nvidia/Nemotron-H-8B-Base-8K:
   accuracy: 69.180
 microsoft/Phi-4-mini-instruct:
   - accuracy: 68.98
+# Created a dummy accuracy to track tp_size=2 for phi4-mini model.
+# TODO: update once https://nvbugs/5393849 is fixed.
+microsoft/Phi-4-mini-instruct-tp2:
+  - accuracy: 0.0
 nvidia/Llama-3_1-Nemotron-Ultra-253B-v1:
   - accuracy: 83.70
   - quant_algo: FP8
diff --git a/tests/integration/defs/accuracy/test_cli_flow.py b/tests/integration/defs/accuracy/test_cli_flow.py
index 133c99d811..647341dc11 100644
--- a/tests/integration/defs/accuracy/test_cli_flow.py
+++ b/tests/integration/defs/accuracy/test_cli_flow.py
@@ -366,6 +366,13 @@ class TestPhi4MiniInstruct(CliFlowAccuracyTestHarness):
     def test_auto_dtype(self):
         self.run(tasks=[MMLU(self.MODEL_NAME)], dtype='auto')
 
+    @pytest.mark.skip_less_device(2)
+    def test_tp2(self):
+        # Created a dummy accuracy to track tp_size=2 for phi4-mini model.
+        # TODO: update once https://nvbugs/5393849 is fixed.
+        MODEL_NAME = "microsoft/Phi-4-mini-instruct-tp2"
+        self.run(tasks=[MMLU(MODEL_NAME)], tp_size=2)
+
 
 # Long sequence length test:
 # Model FP16 7B + 32K tokens in KV cache = 14 * 1024 MB + 32K * 0.5 MB = 30720 MB + scratch memory
diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt
index 1c073d91a5..3ce6314fec 100644
--- a/tests/integration/test_lists/qa/examples_test_list.txt
+++ b/tests/integration/test_lists/qa/examples_test_list.txt
@@ -310,6 +310,7 @@ accuracy/test_cli_flow.py::TestPhi3Small8kInstruct::test_auto_dtype
 accuracy/test_cli_flow.py::TestPhi3Small128kInstruct::test_auto_dtype
 accuracy/test_cli_flow.py::TestPhi3_5MiniInstruct::test_auto_dtype
 accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_auto_dtype
+accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_tp2
 accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype
 accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive
 accuracy/test_cli_flow.py::TestMamba130M::test_auto_dtype
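
Reviewer note: the fix has two parts. Phi3ForCausalLM is added to the list of
architectures routed through split_weights_tp(), and every bias split inside
that function now tolerates a missing bias, since Phi-4-mini checkpoints
ship without attention or MLP biases and the old code called the splitters
on a None bias. The short sketch below shows the guard pattern in isolation
for a generic row-parallel linear layer; split_rows and split_linear_tp are
hypothetical helper names for illustration, not the TensorRT-LLM API.

    # Illustrative sketch only (not TensorRT-LLM API): the None-guard
    # pattern this patch applies, shown for a generic row-parallel layer.
    from typing import Optional, Tuple

    import torch


    def split_rows(t: torch.Tensor, tp_size: int, rank: int) -> torch.Tensor:
        # Each rank keeps one contiguous chunk along the output dimension.
        assert t.shape[0] % tp_size == 0, "dim 0 must divide evenly by tp_size"
        return torch.chunk(t, tp_size, dim=0)[rank].contiguous()


    def split_linear_tp(
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        tp_size: int,
        rank: int,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        split_weight = split_rows(weight, tp_size, rank)
        # Phi-4-mini (Phi3ForCausalLM) has no bias here, so `bias` arrives
        # as None; splitting it unconditionally is what crashed the
        # pre-patch loader for tp_size > 1. Guard instead of assuming.
        split_bias = (split_rows(bias, tp_size, rank)
                      if bias is not None else None)
        return split_weight, split_bias

For a bias-free checkpoint, split_linear_tp(w, None, tp_size=2, rank=0)
returns the rank-0 half of w and None, matching the split_bias = None
branches added in split_weights.py above.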