mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
fix:https://nvbugs/5234033 enable starcoder trt-flow with transforme… (#3909)
fix: https://nvbugs/5234033 enable starcoder trt-flow with transformers 4.51.3. Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
This commit is contained in:
parent
5dc3b539ba
commit
e76cf9d9fe
@ -459,9 +459,9 @@ def load_weights_from_hf_model(hf_model,
|
||||
f'{prefix}.self_attn.k_proj', dtype)
|
||||
v_w, v_b = get_weight_and_bias(model_params,
|
||||
f'{prefix}.self_attn.v_proj', dtype)
|
||||
qkv_w = torch.cat([q_w, k_w, v_w], dim=0)
|
||||
qkv_b = torch.cat([q_b, k_b, v_b],
|
||||
dim=0) if q_b is not None else None
|
||||
qkv_w = torch.cat([q_w.cuda(), k_w.cuda(), v_w.cuda()], dim=0)
|
||||
qkv_b = torch.cat([q_b.cuda(), k_b.cuda(),
|
||||
v_b.cuda()], dim=0) if q_b is not None else None
|
||||
elif gpt_variant == 'persimmon':
|
||||
qkv_w, qkv_b = get_weight_and_bias(
|
||||
model_params, f'{prefix}.self_attn.query_key_value', dtype)
|
||||
|
||||
@ -265,6 +265,9 @@ class Parameter:
|
||||
def _regularize_value(self, value):
|
||||
if isinstance(value, np.ndarray):
|
||||
return value
|
||||
|
||||
elif isinstance(value, torch.distributed.tensor.DTensor):
|
||||
return value.to_local().cpu().numpy()
|
||||
elif isinstance(value, torch.Tensor):
|
||||
return torch_to_numpy(value)
|
||||
raise TypeError(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user