Nupur Kumari commited on
Commit
71da51f
·
1 Parent(s): 2ab48ae

custom-diffusion-space

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. inference.py +22 -21
app.py CHANGED
@@ -99,9 +99,10 @@ def create_training_demo(trainer: Trainer,
99
  use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
100
  gradient_checkpointing = gr.Checkbox(label='Enable gradient checkpointing', value=False)
101
  gr.Markdown('''
102
- - Only enable one of "Train Text Encoder" or "modifier token" or None.
103
- - It will take about ~10 minutes to train for 1000 steps and ~21GB on a 3090 GPU.
104
  - Enable gradient checkpointing for lower memory requirements (~14GB) at the expense of slower backward pass.
 
105
  - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
106
  ''')
107
 
 
99
  use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
100
  gradient_checkpointing = gr.Checkbox(label='Enable gradient checkpointing', value=False)
101
  gr.Markdown('''
102
+ - It will take about ~10 minutes to train for 1000 steps and ~21GB on a 3090 GPU.
103
+ - Our results in the paper are with the above batch-size of 2 and 2 GPUs.
104
  - Enable gradient checkpointing for lower memory requirements (~14GB) at the expense of slower backward pass.
105
+ - If "Train Text Encoder", disable "modifier token".
106
  - Note that your trained models will be deleted when the second training is started. You can upload your trained model in the "Upload" tab.
107
  ''')
108
 
inference.py CHANGED
@@ -12,26 +12,27 @@ import torch
12
  from diffusers import StableDiffusionPipeline
13
  sys.path.insert(0, 'custom-diffusion')
14
 
15
-
16
- def load_model(text_encoder, tokenizer, unet, save_path, modifier_token, freeze_model='crossattn_kv'):
17
- st = torch.load(save_path)
18
- if 'text_encoder' in st:
19
- text_encoder.load_state_dict(st['text_encoder'])
20
- if modifier_token in st:
21
- _ = tokenizer.add_tokens(modifier_token)
22
- modifier_token_id = tokenizer.convert_tokens_to_ids(modifier_token)
23
- # Resize the token embeddings as we are adding new special tokens to the tokenizer
24
- text_encoder.resize_token_embeddings(len(tokenizer))
25
- token_embeds = text_encoder.get_input_embeddings().weight.data
26
- token_embeds[modifier_token_id] = st[modifier_token]
27
- print(st.keys())
28
- for name, params in unet.named_parameters():
29
- if freeze_model == 'crossattn':
30
- if 'attn2' in name:
31
- params.data.copy_(st['unet'][f'{name}'])
32
- else:
33
- if 'attn2.to_k' in name or 'attn2.to_v' in name:
34
- params.data.copy_(st['unet'][f'{name}'])
 
35
 
36
 
37
  class InferencePipeline:
@@ -67,7 +68,7 @@ class InferencePipeline:
67
  model_id, torch_dtype=torch.float16)
68
  pipe = pipe.to(self.device)
69
 
70
- load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, weight_path, '<new1>')
71
 
72
  self.pipe = pipe
73
 
 
12
  from diffusers import StableDiffusionPipeline
13
  sys.path.insert(0, 'custom-diffusion')
14
 
15
+ from sys import diffuser_training
16
+
17
+ # def load_model(text_encoder, tokenizer, unet, save_path, modifier_token, freeze_model='crossattn_kv'):
18
+ # st = torch.load(save_path)
19
+ # if 'text_encoder' in st:
20
+ # text_encoder.load_state_dict(st['text_encoder'])
21
+ # if modifier_token in st:
22
+ # _ = tokenizer.add_tokens(modifier_token)
23
+ # modifier_token_id = tokenizer.convert_tokens_to_ids(modifier_token)
24
+ # # Resize the token embeddings as we are adding new special tokens to the tokenizer
25
+ # text_encoder.resize_token_embeddings(len(tokenizer))
26
+ # token_embeds = text_encoder.get_input_embeddings().weight.data
27
+ # token_embeds[modifier_token_id] = st[modifier_token]
28
+ # print(st.keys())
29
+ # for name, params in unet.named_parameters():
30
+ # if freeze_model == 'crossattn':
31
+ # if 'attn2' in name:
32
+ # params.data.copy_(st['unet'][f'{name}'])
33
+ # else:
34
+ # if 'attn2.to_k' in name or 'attn2.to_v' in name:
35
+ # params.data.copy_(st['unet'][f'{name}'])
36
 
37
 
38
  class InferencePipeline:
 
68
  model_id, torch_dtype=torch.float16)
69
  pipe = pipe.to(self.device)
70
 
71
+ diffuser_training.load_model(pipe.text_encoder, pipe.tokenizer, pipe.unet, weight_path, '<new1>')
72
 
73
  self.pipe = pipe
74