![image/png](https://cdn-uploads.huggingface.co/production/uploads/623ce1c6b66fedf374859fe7/R7sr9kAwZjRk_80oMY54h.png)

## Fine-tuning script

Download this script: [SDXL DreamBooth-LoRA_Fine-Tune.ipynb](https://huggingface.co/lamm-mit/SDXL-leaf-inspired/resolve/main/SDXL_DreamBooth_LoRA_Fine-Tune.ipynb)

Create a local folder ```leaf_concept_dir_SDXL``` and add the leaf images provided in this repository (see the image subfolder).

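A minimal setup sketch follows; the subfolder name ```leaf_images``` is a placeholder (substitute the actual subfolder in this repository), and the filenames match the metadata shown further down:

```python
import os
import shutil
from huggingface_hub import hf_hub_download

# placeholder subfolder name -- check this repository for the actual one
repo_id, subfolder = "lamm-mit/SDXL-leaf-inspired", "leaf_images"

os.makedirs("leaf_concept_dir_SDXL", exist_ok=True)
for fname in ["0.jpeg", "1.jpeg", "2.jpeg", "3.jpeg", "87.jpg",
              "88.jpg", "90.jpg", "91.jpg", "92.jpg", "94.jpg"]:
    cached = hf_hub_download(repo_id=repo_id, filename=f"{subfolder}/{fname}")
    shutil.copy(cached, f"leaf_concept_dir_SDXL/{fname}")
```
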
The notebook automatically downloads the training script, ```train_dreambooth_lora_sdxl.py```.

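If you are running outside the notebook, the script can be fetched directly from the diffusers examples; a minimal sketch:

```python
import requests

# fetch the DreamBooth-LoRA SDXL training script from the diffusers examples
url = ("https://raw.githubusercontent.com/huggingface/diffusers/"
       "main/examples/dreambooth/train_dreambooth_lora_sdxl.py")
with open("train_dreambooth_lora_sdxl.py", "w") as f:
    f.write(requests.get(url).text)
```
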
The training script can handle custom per-image prompts, which are generated here using BLIP.

For the images used in this example, the generated prompts are:

```raw
['<leaf microstructure>, a close up of a green plant with a lot of small holes',
 '<leaf microstructure>, a close up of a leaf with a small insect on it',
 '<leaf microstructure>, a close up of a plant with a lot of green leaves',
 '<leaf microstructure>, a close up of a green plant with a yellow light',
 '<leaf microstructure>, a close up of a green plant with a white center',
 '<leaf microstructure>, arafed leaf with a white line on the center',
 '<leaf microstructure>, a close up of a leaf with a yellow light shining through it',
 '<leaf microstructure>, arafed image of a green plant with a yellow cross']
```

Training then proceeds as:

```python
HF_username = 'lamm-mit'

pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
pretrained_vae_model_name_or_path = "madebyollin/sdxl-vae-fp16-fix"

instance_prompt = "<leaf microstructure>"
instance_data_dir = "./leaf_concept_dir_SDXL/"

val_prompt = "a vase that resembles a <leaf microstructure>, high quality"
val_epochs = 100

instance_output_dir = "leaf_LoRA_SDXL_V10"  # for checkpointing
```

Dataset generation with custom per-image captions:

```python
import requests
from transformers import AutoProcessor, BlipForConditionalGeneration
import torch
import glob
from PIL import Image
import json

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the processor and the captioning model
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
).to(device)

# captioning utility
def caption_images(input_image):
    inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
    pixel_values = inputs.pixel_values

    generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
    generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_caption

# collect (path, image) pairs for every leaf image in the instance data folder
imgs_and_paths = [
    (path, Image.open(path))
    for path in sorted(glob.glob(f"{instance_data_dir}*.jpeg") + glob.glob(f"{instance_data_dir}*.jpg"))
]

# caption each image, prepend the instance token, and write metadata.jsonl
caption_prefix = f"{instance_prompt}, "
with open(f'{instance_data_dir}metadata.jsonl', 'w') as outfile:
    for img in imgs_and_paths:
        caption = caption_prefix + caption_images(img[1]).split("\n")[0]
        entry = {"file_name": img[0].split("/")[-1], "prompt": caption}
        json.dump(entry, outfile)
        outfile.write('\n')
```
This produces a ```metadata.jsonl``` file (one JSON record per line) in the ```instance_data_dir``` directory:

```json
{"file_name": "0.jpeg", "prompt": "<leaf microstructure>, a close up of a green plant with a lot of small holes"}
{"file_name": "1.jpeg", "prompt": "<leaf microstructure>, a close up of a leaf with a small insect on it"}
{"file_name": "2.jpeg", "prompt": "<leaf microstructure>, a close up of a plant with a lot of green leaves"}
{"file_name": "3.jpeg", "prompt": "<leaf microstructure>, a close up of a leaf with a yellow substance in it"}
{"file_name": "87.jpg", "prompt": "<leaf microstructure>, a close up of a green plant with a yellow light"}
{"file_name": "88.jpg", "prompt": "<leaf microstructure>, a close up of a green plant with a white center"}
{"file_name": "90.jpg", "prompt": "<leaf microstructure>, arafed leaf with a white line on the center"}
{"file_name": "91.jpg", "prompt": "<leaf microstructure>, arafed image of a green leaf with a white spot"}
{"file_name": "92.jpg", "prompt": "<leaf microstructure>, a close up of a leaf with a yellow light shining through it"}
{"file_name": "94.jpg", "prompt": "<leaf microstructure>, arafed image of a green plant with a yellow cross"}
```

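Before launching training, it can be worth checking that the image folder and ```metadata.jsonl``` pair loads cleanly; a sketch using the datasets library's documented ImageFolder loader, which is how folders like this are typically consumed:

```python
from datasets import load_dataset

# load the image folder together with metadata.jsonl; extra metadata
# columns (here "prompt") become dataset columns
dataset = load_dataset("imagefolder", data_dir=instance_data_dir, split="train")
print(len(dataset), dataset[0]["prompt"])
```
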
The fine-tuning run is then launched from a notebook cell (the ```{...}``` placeholders are filled in from the Python variables defined above):

```raw
!accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="{pretrained_model_name_or_path}" \
  --pretrained_vae_model_name_or_path="{pretrained_vae_model_name_or_path}" \
  --dataset_name="{instance_data_dir}" \
  --output_dir="{instance_output_dir}" \
  --caption_column="prompt" \
  --mixed_precision="fp16" \
  --instance_prompt="{instance_prompt}" \
  --validation_prompt="{val_prompt}" \
  --validation_epochs="{val_epochs}" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=3 \
  --gradient_checkpointing \
  --learning_rate=1e-4 \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --use_8bit_adam \
  --max_train_steps=500 \
  --checkpointing_steps=500 \
  --seed="0"
```
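
Once training completes, the adapter can be sanity-checked by loading it into an SDXL pipeline; a minimal inference sketch (the directory follows ```instance_output_dir``` above):

```python
from diffusers import DiffusionPipeline
import torch

# load the SDXL base model and attach the trained LoRA adapter
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("leaf_LoRA_SDXL_V10")

# generate with the trained concept token in the prompt
image = pipe("a vase that resembles a <leaf microstructure>, high quality").images[0]
image.save("leaf_vase_test.png")
```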