Pass LoRA rank to LoRALinearLayer (#2191)
@@ -296,10 +296,10 @@ class LoRACrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
         super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
@@ -408,10 +408,10 @@ class LoRAXFormersCrossAttnProcessor(nn.Module):
     def __init__(self, hidden_size, cross_attention_dim, rank=4):
         super().__init__()
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
 
     def __call__(
         self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
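
Context for the change: both LoRA attention processors accept a `rank` argument but previously never forwarded it, so `LoRALinearLayer` always fell back to its own default rank of 4. The sketch below is not the library's verbatim implementation; it is a minimal, self-contained illustration of what the third positional argument controls in a standard LoRA layer (bottleneck width of the down/up projection), assuming the usual down-then-up factorization and zero-initialized up weights.

import torch
import torch.nn as nn


class LoRALinearLayer(nn.Module):
    # Minimal sketch of a LoRA adapter layer (assumed structure, not copied
    # from the library): a low-rank update from in_features to out_features
    # factored through a bottleneck of width `rank`, which is the third
    # positional argument the processors above now forward.
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        # Down-projection into the rank-sized bottleneck, then up-projection.
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        # Common LoRA initialization: zero up-weights make the adapter a
        # no-op before any training step.
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        return self.up(self.down(hidden_states))


# Before this commit, a call such as LoRACrossAttnProcessor(hidden_size=320, rank=8)
# still built rank-4 adapters because `rank` was dropped; after the change the
# requested rank reaches the layer:
layer = LoRALinearLayer(320, 320, rank=8)
print(layer.down.weight.shape, layer.up.weight.shape)  # torch.Size([8, 320]) torch.Size([320, 8])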