fix:https://nvbugs/5234033 enable starcoder trt-flow with transforme… (#3909)

fix:https://nvbugs/5234033 enable starcoder trt-flow with transformers 4.51.3.

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
This commit is contained in:
nv-guomingz 2025-05-15 11:16:45 +08:00 committed by GitHub
parent 5dc3b539ba
commit e76cf9d9fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 3 deletions

View File

@ -459,9 +459,9 @@ def load_weights_from_hf_model(hf_model,
f'{prefix}.self_attn.k_proj', dtype)
v_w, v_b = get_weight_and_bias(model_params,
f'{prefix}.self_attn.v_proj', dtype)
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
qkv_b = torch.cat([q_b, k_b, v_b],
dim=0) if q_b is not None else None
qkv_w = torch.cat([q_w.cuda(), k_w.cuda(), v_w.cuda()], dim=0)
qkv_b = torch.cat([q_b.cuda(), k_b.cuda(),
v_b.cuda()], dim=0) if q_b is not None else None
elif gpt_variant == 'persimmon':
qkv_w, qkv_b = get_weight_and_bias(
model_params, f'{prefix}.self_attn.query_key_value', dtype)

View File

@ -265,6 +265,9 @@ class Parameter:
def _regularize_value(self, value):
if isinstance(value, np.ndarray):
return value
elif isinstance(value, torch.distributed.tensor.DTensor):
return value.to_local().cpu().numpy()
elif isinstance(value, torch.Tensor):
return torch_to_numpy(value)
raise TypeError(