mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
fix: Unable to load phi4-model with tp_size>1 (#6093)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
This commit is contained in:
parent
c18b632160
commit
714f82b485
@ -126,8 +126,9 @@ def load_weights_from_hf_model(hf_model, config):
|
||||
if "qkv." in key:
|
||||
weights[key] = shuffle_qkv_weights(weights[key], config)
|
||||
|
||||
if config.architecture in ['Phi3SmallForCausalLM', "PhiMoEForCausalLM"
|
||||
] and config.mapping.has_tp():
|
||||
if config.architecture in [
|
||||
'Phi3SmallForCausalLM', "PhiMoEForCausalLM", "Phi3ForCausalLM"
|
||||
] and config.mapping.has_tp():
|
||||
weights = split_weights_tp(config, weights, torch_dtype)
|
||||
|
||||
return weights
|
||||
|
||||
@ -145,19 +145,23 @@ def split_weights_tp(config, weights, dtype):
|
||||
-1, hidden_size)
|
||||
split_weight = torch.cat(
|
||||
[split(x, tp_size, rank) for x in [q, k, v]], dim=0)
|
||||
|
||||
qkv_bias = qkv_bias.reshape(num_q_per_kv + 2, -1)
|
||||
q = qkv_bias[:num_q_per_kv, :].reshape(-1)
|
||||
k = qkv_bias[num_q_per_kv:num_q_per_kv + 1, :].reshape(-1)
|
||||
v = qkv_bias[num_q_per_kv + 1:num_q_per_kv + 2, :].reshape(-1)
|
||||
split_bias = torch.cat([split(x, tp_size, rank) for x in [q, k, v]],
|
||||
dim=0)
|
||||
if qkv_bias is not None:
|
||||
qkv_bias = qkv_bias.reshape(num_q_per_kv + 2, -1)
|
||||
q = qkv_bias[:num_q_per_kv, :].reshape(-1)
|
||||
k = qkv_bias[num_q_per_kv:num_q_per_kv + 1, :].reshape(-1)
|
||||
v = qkv_bias[num_q_per_kv + 1:num_q_per_kv + 2, :].reshape(-1)
|
||||
split_bias = torch.cat(
|
||||
[split(x, tp_size, rank) for x in [q, k, v]], dim=0)
|
||||
else:
|
||||
split_bias = None
|
||||
else:
|
||||
split_weight = split_qkv_tp(qkv_weight, num_heads, hidden_size,
|
||||
tp_size, rank)
|
||||
split_bias = split_qkv_bias_tp(qkv_bias, num_heads, hidden_size,
|
||||
tp_size, rank)
|
||||
|
||||
if qkv_bias is not None:
|
||||
split_bias = split_qkv_bias_tp(qkv_bias, num_heads, hidden_size,
|
||||
tp_size, rank)
|
||||
else:
|
||||
split_bias = None
|
||||
weights.update(get_quant_weight(split_weight, prefix, split_bias))
|
||||
|
||||
prefix = layer_prefix + 'attention.dense'
|
||||
@ -171,7 +175,10 @@ def split_weights_tp(config, weights, dtype):
|
||||
mlp_fc_weight, mlp_fc_bias = get_weight_and_bias(
|
||||
weights, prefix, dtype)
|
||||
split_v = split_matrix_tp(mlp_fc_weight, tp_size, rank, dim=0)
|
||||
bias = split_matrix_tp(mlp_fc_bias, tp_size, rank, dim=0)
|
||||
if mlp_fc_bias is not None:
|
||||
bias = split_matrix_tp(mlp_fc_bias, tp_size, rank, dim=0)
|
||||
else:
|
||||
bias = None
|
||||
weights.update(get_quant_weight(split_v, prefix, bias))
|
||||
else:
|
||||
mlp_fc_weight = get_weight(weights, prefix, dtype)
|
||||
|
||||
@ -175,6 +175,10 @@ nvidia/Nemotron-H-8B-Base-8K:
|
||||
accuracy: 69.180
|
||||
microsoft/Phi-4-mini-instruct:
|
||||
- accuracy: 68.98
|
||||
# Created a dummy accuracy to track tp_size=2 for phi4-mini model.
|
||||
# TODO: update once https://nvbugs/5393849 is fixed.
|
||||
microsoft/Phi-4-mini-instruct-tp2:
|
||||
- accuracy: 0.0
|
||||
nvidia/Llama-3_1-Nemotron-Ultra-253B-v1:
|
||||
- accuracy: 83.70
|
||||
- quant_algo: FP8
|
||||
|
||||
@ -366,6 +366,13 @@ class TestPhi4MiniInstruct(CliFlowAccuracyTestHarness):
|
||||
def test_auto_dtype(self):
|
||||
self.run(tasks=[MMLU(self.MODEL_NAME)], dtype='auto')
|
||||
|
||||
@pytest.mark.skip_less_device(2)
|
||||
def test_tp2(self):
|
||||
# Created a dummy accuracy to track tp_size=2 for phi4-mini model.
|
||||
# TODO: update once https://nvbugs/5393849 is fixed.
|
||||
MODEL_NAME = "microsoft/Phi-4-mini-instruct-tp2"
|
||||
self.run(tasks=[MMLU(MODEL_NAME)], tp_size=2)
|
||||
|
||||
|
||||
# Long sequence length test:
|
||||
# Model FP16 7B + 32K tokens in KV cache = 14 * 1024 MB + 32K * 0.5 MB = 30720 MB + scratch memory
|
||||
|
||||
@ -310,6 +310,7 @@ accuracy/test_cli_flow.py::TestPhi3Small8kInstruct::test_auto_dtype
|
||||
accuracy/test_cli_flow.py::TestPhi3Small128kInstruct::test_auto_dtype
|
||||
accuracy/test_cli_flow.py::TestPhi3_5MiniInstruct::test_auto_dtype
|
||||
accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_auto_dtype
|
||||
accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_tp2
|
||||
accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype
|
||||
accuracy/test_cli_flow.py::TestLongAlpaca7B::test_multiblock_aggressive
|
||||
accuracy/test_cli_flow.py::TestMamba130M::test_auto_dtype
|
||||
|
||||
Loading…
Reference in New Issue
Block a user