Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 62ece6ab5a | |||
| 7722f5b67c | |||
| 0281e85827 | |||
| 9106e66382 | |||
| 2fb3d141c2 | |||
| 6e1b06c01c | |||
| 50f2544697 | |||
| fc5fc8c8d2 | |||
| 030bb528ba | |||
| f803d3d1f5 | |||
| 3758d7a8b0 | |||
| 3a794b54c9 | |||
| fe623f3bea | |||
| bc65f829b7 | |||
| c22be1a557 | |||
| 05f716d4ac | |||
| 25b0d5b8c4 |
@@ -755,17 +755,26 @@ def main(args):
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
||||
in_features=attn_module.to_q.in_features,
|
||||
out_features=attn_module.to_q.out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
||||
in_features=attn_module.to_k.in_features,
|
||||
out_features=attn_module.to_k.out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
||||
in_features=attn_module.to_v.in_features,
|
||||
out_features=attn_module.to_v.out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
@@ -773,6 +782,7 @@ def main(args):
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -207,6 +207,10 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
# print("After first norm")
|
||||
# print(f"hidden_states: {hidden_states.dtype}")
|
||||
# print(f"norm_hidden_states: {norm_hidden_states.dtype}")
|
||||
# print(f"encoder_hidden_states: {norm_hidden_states.dtype}")
|
||||
|
||||
# 1. Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
@@ -223,7 +227,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
if self.use_ada_layer_norm_zero:
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
# print(f"attn_output: {attn_output.dtype}")
|
||||
hidden_states = attn_output + hidden_states
|
||||
# print(f"attn_output: {attn_output.dtype}")
|
||||
|
||||
# 2.5 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
|
||||
@@ -84,6 +84,7 @@ class LoRALinearLayer(nn.Module):
|
||||
nn.init.zeros_(self.up.weight)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# print(f"From {self.__class__.__name__}: hidden_states: {hidden_states.dtype}")
|
||||
orig_dtype = hidden_states.dtype
|
||||
dtype = self.down.weight.dtype
|
||||
|
||||
@@ -93,7 +94,9 @@ class LoRALinearLayer(nn.Module):
|
||||
if self.network_alpha is not None:
|
||||
up_hidden_states *= self.network_alpha / self.rank
|
||||
|
||||
return up_hidden_states.to(orig_dtype)
|
||||
out = up_hidden_states.to(orig_dtype)
|
||||
# print(f"From {self.__class__.__name__}: out: {out.dtype}")
|
||||
return out
|
||||
|
||||
|
||||
class LoRAConv2dLayer(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user