fix: Unable to load phi4-model with tp_size>1 (#6093)

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Author: Wanli Jiang
Date: 2025-07-18 10:16:36 +08:00 (committed by GitHub)
Parent: c18b632160
Commit: 714f82b485
5 changed files with 33 additions and 13 deletions


@@ -126,8 +126,9 @@ def load_weights_from_hf_model(hf_model, config):
         if "qkv." in key:
             weights[key] = shuffle_qkv_weights(weights[key], config)
 
-    if config.architecture in ['Phi3SmallForCausalLM', "PhiMoEForCausalLM"
-                               ] and config.mapping.has_tp():
+    if config.architecture in [
+            'Phi3SmallForCausalLM', "PhiMoEForCausalLM", "Phi3ForCausalLM"
+    ] and config.mapping.has_tp():
         weights = split_weights_tp(config, weights, torch_dtype)
 
     return weights
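The functional change here is small: Phi3ForCausalLM, the architecture reported by Phi-4-mini checkpoints, is added to the set of models whose weights go through split_weights_tp when tensor parallelism is enabled. A minimal sketch of that gate, using plain values instead of the real PretrainedConfig/Mapping objects (the helper below is a stand-in, not the library API):

# Hypothetical stand-in for the gating logic above; the architecture strings are
# real, the helper itself is illustrative only.
TP_SPLIT_ARCHS = {"Phi3SmallForCausalLM", "PhiMoEForCausalLM", "Phi3ForCausalLM"}

def needs_tp_split(architecture: str, tp_size: int) -> bool:
    # Phi-4-mini reports Phi3ForCausalLM, so with tp_size > 1 it now takes the
    # tensor-parallel splitting path instead of loading unsplit weights.
    return architecture in TP_SPLIT_ARCHS and tp_size > 1

assert needs_tp_split("Phi3ForCausalLM", 2)       # the case this commit fixes
assert not needs_tp_split("Phi3ForCausalLM", 1)   # single-GPU path is unchanged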


@@ -145,19 +145,23 @@ def split_weights_tp(config, weights, dtype):
                 -1, hidden_size)
             split_weight = torch.cat(
                 [split(x, tp_size, rank) for x in [q, k, v]], dim=0)
-            qkv_bias = qkv_bias.reshape(num_q_per_kv + 2, -1)
-            q = qkv_bias[:num_q_per_kv, :].reshape(-1)
-            k = qkv_bias[num_q_per_kv:num_q_per_kv + 1, :].reshape(-1)
-            v = qkv_bias[num_q_per_kv + 1:num_q_per_kv + 2, :].reshape(-1)
-            split_bias = torch.cat([split(x, tp_size, rank) for x in [q, k, v]],
-                                   dim=0)
+            if qkv_bias is not None:
+                qkv_bias = qkv_bias.reshape(num_q_per_kv + 2, -1)
+                q = qkv_bias[:num_q_per_kv, :].reshape(-1)
+                k = qkv_bias[num_q_per_kv:num_q_per_kv + 1, :].reshape(-1)
+                v = qkv_bias[num_q_per_kv + 1:num_q_per_kv + 2, :].reshape(-1)
+                split_bias = torch.cat(
+                    [split(x, tp_size, rank) for x in [q, k, v]], dim=0)
+            else:
+                split_bias = None
         else:
             split_weight = split_qkv_tp(qkv_weight, num_heads, hidden_size,
                                         tp_size, rank)
-            split_bias = split_qkv_bias_tp(qkv_bias, num_heads, hidden_size,
-                                           tp_size, rank)
+            if qkv_bias is not None:
+                split_bias = split_qkv_bias_tp(qkv_bias, num_heads, hidden_size,
+                                               tp_size, rank)
+            else:
+                split_bias = None
         weights.update(get_quant_weight(split_weight, prefix, split_bias))
 
         prefix = layer_prefix + 'attention.dense'
@@ -171,7 +175,10 @@ def split_weights_tp(config, weights, dtype):
             mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(
                 weights, prefix, dtype)
             split_v = split_matrix_tp(mlp_fc_weight, tp_size, rank, dim=0)
-            bias = split_matrix_tp(mlp_fc_bias, tp_size, rank, dim=0)
+            if mlp_fc_bias is not None:
+                bias = split_matrix_tp(mlp_fc_bias, tp_size, rank, dim=0)
+            else:
+                bias = None
             weights.update(get_quant_weight(split_v, prefix, bias))
         else:
             mlp_fc_weight = get_weight(weights, prefix, dtype)
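Both hunks apply the same guard: Phi-4-mini checkpoints ship without attention and MLP biases, so every bias-splitting call must tolerate None instead of assuming a tensor. A self-contained sketch of the pattern, with split approximated by torch.chunk (the real helper lives in TensorRT-LLM's conversion utilities):

import torch

def split(tensor, tp_size, rank, dim=0):
    # Stand-in for the split helper used in the diff: keep this rank's chunk.
    return torch.chunk(tensor, tp_size, dim=dim)[rank].contiguous()

def split_optional_bias(bias, tp_size, rank):
    # The essence of the fix: a missing bias simply stays missing per rank,
    # rather than crashing when reshape/split is applied to None.
    if bias is None:
        return None
    return split(bias, tp_size, rank, dim=0)

print(split_optional_bias(None, 2, 0))              # None, no error
print(split_optional_bias(torch.arange(8.0), 2, 1)) # tensor([4., 5., 6., 7.])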


@@ -175,6 +175,10 @@ nvidia/Nemotron-H-8B-Base-8K:
   - accuracy: 69.180
 microsoft/Phi-4-mini-instruct:
   - accuracy: 68.98
+# Created a dummy accuracy to track tp_size=2 for phi4-mini model.
+# TODO: update once https://nvbugs/5393849 is fixed.
+microsoft/Phi-4-mini-instruct-tp2:
+  - accuracy: 0.0
 nvidia/Llama-3_1-Nemotron-Ultra-253B-v1:
   - accuracy: 83.70
   - quant_algo: FP8
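The dummy entry exists because the accuracy harness looks up its reference score by model name, and the test_tp2 case added below reports itself as microsoft/Phi-4-mini-instruct-tp2. A rough sketch of that lookup (illustrative only; the actual harness loads this YAML through its own machinery):

import yaml  # assumes PyYAML is available; used here purely for illustration

references = yaml.safe_load("""
microsoft/Phi-4-mini-instruct:
  - accuracy: 68.98
microsoft/Phi-4-mini-instruct-tp2:
  - accuracy: 0.0
""")

# A threshold of 0.0 means any score passes until https://nvbugs/5393849 is
# fixed and a real reference accuracy is recorded.
print(references["microsoft/Phi-4-mini-instruct-tp2"][0]["accuracy"])  # 0.0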


@@ -366,6 +366,13 @@ class TestPhi4MiniInstruct(CliFlowAccuracyTestHarness):
     def test_auto_dtype(self):
         self.run(tasks=[MMLU(self.MODEL_NAME)], dtype='auto')
 
+    @pytest.mark.skip_less_device(2)
+    def test_tp2(self):
+        # Created a dummy accuracy to track tp_size=2 for phi4-mini model.
+        # TODO: update once https://nvbugs/5393849 is fixed.
+        MODEL_NAME = "microsoft/Phi-4-mini-instruct-tp2"
+        self.run(tasks=[MMLU(MODEL_NAME)], tp_size=2)
+
 
 # Long sequence length test:
 # Model FP16 7B + 32K tokens in KV cache = 14 * 1024 MB + 32K * 0.5 MB = 30720 MB + scratch memory
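The memory estimate in that context comment works out as follows; a worked restatement of its arithmetic, with all figures (7B parameters, FP16, 32K tokens, 0.5 MB per cached token) taken from the comment itself:

# Worked restatement of the comment's estimate; no new measurements here.
model_mb = 7 * 2 * 1024        # 7B params * 2 bytes (FP16) ~= 14 * 1024 MB = 14336 MB
kv_cache_mb = 32 * 1024 * 0.5  # 32K tokens * 0.5 MB per token = 16384 MB
total_mb = model_mb + kv_cache_mb
print(total_mb)                # 30720.0 MB, before scratch memory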


@@ -310,6 +310,7 @@ accuracy/test_cli_flow.py::TestPhi3Small8kInstruct::test_auto_dtype
 accuracy/test_cli_flow.py::TestPhi3Small128kInstruct::test_auto_dtype
 accuracy/test_cli_flow.py::TestPhi3_5MiniInstruct::test_auto_dtype
 accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_auto_dtype
+accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_tp2
 accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype
 accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive
 accuracy/test_cli_flow.py::TestMamba130M::test_auto_dtype