Spaces:
Running
Running
File size: 1,402 Bytes
2eac672 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
from typing import Optional
# Dotted name prefixes of attention blocks, grouped under the 'content' and
# 'style' keys. get_target_modules() falls back to the union of both lists
# when the caller does not supply its own block list.
BLOCKS = {
    "content": ["unet.up_blocks.0.attentions.0"],
    "style": ["unet.up_blocks.0.attentions.1"],
}
def is_belong_to_blocks(key, blocks):
    """Return True if any entry of *blocks* occurs inside *key*.

    Args:
        key: Value searched with the ``in`` operator (a dotted parameter or
            module name in this file's callers).
        blocks: Iterable of candidate fragments to look for in *key*.

    Returns:
        True as soon as one fragment is contained in *key*; False when none
        match or *blocks* is empty.

    Raises:
        Exception: Any error raised while iterating or matching is re-raised
            as the same exception class with a contextual message, chained to
            the original error.
    """
    try:
        return any(fragment in key for fragment in blocks)
    except Exception as e:
        # type(e) keeps the exception class so existing `except` clauses still
        # match; `from e` records the original error as __cause__.
        raise type(e)(f'failed to is_belong_to_blocks, due to: {e}') from e
def filter_lora(state_dict, blocks_):
    """Keep only the entries of *state_dict* whose key matches a block.

    Args:
        state_dict: Mapping to filter; must provide ``.items()``.
        blocks_: Iterable of fragments; an entry is kept when
            ``is_belong_to_blocks(key, blocks_)`` is true for its key.

    Returns:
        A new dict holding the matching ``key: value`` pairs (values are not
        copied).

    Raises:
        Exception: Any error raised while filtering is re-raised as the same
            exception class with a contextual message, chained to the
            original error.
    """
    try:
        return {
            name: value
            for name, value in state_dict.items()
            if is_belong_to_blocks(name, blocks_)
        }
    except Exception as e:
        # Preserve the exception class for callers and chain the root cause.
        raise type(e)(f'failed to filter_lora, due to: {e}') from e
def scale_lora(state_dict, alpha):
    """Multiply every value of *state_dict* by *alpha*.

    Args:
        state_dict: Mapping whose values support ``value * alpha``; must
            provide ``.items()``.
        alpha: Scale factor applied to each value.

    Returns:
        A new dict with the same keys and the scaled values; the input
        mapping is not modified.

    Raises:
        Exception: Any error raised while scaling is re-raised as the same
            exception class with a contextual message, chained to the
            original error.
    """
    try:
        return {name: value * alpha for name, value in state_dict.items()}
    except Exception as e:
        # Preserve the exception class for callers and chain the root cause.
        raise type(e)(f'failed to scale_lora, due to: {e}') from e
def get_target_modules(unet, blocks=None):
    """List the attention projection module names inside the selected blocks.

    Args:
        unet: Object exposing an ``attn_processors`` mapping whose keys are
            dotted attention-processor names.
        blocks: Iterable of name fragments selecting which processors to
            keep. When empty or None, defaults to every entry of
            ``BLOCKS['content']`` and ``BLOCKS['style']`` with its first
            dotted component removed.

    Returns:
        A list of ``"<attention>.<projection>"`` strings, where
        ``<attention>`` is the processor name minus its last dotted component
        and ``<projection>`` is one of ``to_k``, ``to_q``, ``to_v``,
        ``to_out.0``. The list is ordered projection-first: all ``to_k``
        entries, then all ``to_q`` entries, and so on.

    Raises:
        Exception: Any error raised while collecting names is re-raised as
            the same exception class with a contextual message, chained to
            the original error.
    """
    try:
        if not blocks:
            # Drop the first dotted component ("unet" in BLOCKS) from each
            # default prefix. NOTE(review): presumably because the
            # attn_processors keys are relative to the UNet itself - confirm.
            blocks = [
                prefix.partition('.')[2]
                for prefix in BLOCKS['content'] + BLOCKS['style']
            ]
        # Strip the trailing component of each matching processor name to get
        # the name of the attention module that owns it.
        attns = [
            processor_name.rsplit('.', 1)[0]
            for processor_name, _ in unet.attn_processors.items()
            if is_belong_to_blocks(processor_name, blocks)
        ]
        return [
            f'{attn}.{mat}'
            for mat in ["to_k", "to_q", "to_v", "to_out.0"]
            for attn in attns
        ]
    except Exception as e:
        # Preserve the exception class for callers and chain the root cause.
        raise type(e)(f'failed to get_target_modules, due to: {e}') from e
|