Spaces:
Build error
Build error
File size: 4,305 Bytes
9042918 751f591 9042918 3cf1717 9042918 98410be 9042918 1fa5d2c 9042918 1fa5d2c 9042918 98410be 9042918 1fa5d2c 98410be 1fa5d2c 98410be 9042918 98410be 1fa5d2c 9042918 1fa5d2c 9042918 1fa5d2c 9042918 1fa5d2c 9042918 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
from __future__ import annotations
import os
import pathlib
import shlex
import shutil
import subprocess
import gradio as gr
import PIL.Image
import torch
# Put the local 'custom-diffusion' directory first on PYTHONPATH and keep any
# existing entries after it. subprocess.run in Trainer.run inherits
# os.environ, so the training child process sees this value (presumably so
# it can import modules from that checkout — confirm).
# NOTE(review): if PYTHONPATH was unset, this leaves a trailing ':' (an
# empty entry).
os.environ['PYTHONPATH'] = ':'.join(('custom-diffusion', os.environ.get('PYTHONPATH', '')))
def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
    """Pad *image* with black borders so that it becomes square.

    The original image is centred along its shorter axis on a canvas whose
    side equals the longer axis. A square input is returned as-is.

    Args:
        image: Source image, in any Pillow mode.

    Returns:
        ``image`` itself if it is already square, otherwise a new
        ``max(w, h)`` x ``max(w, h)`` image of the same mode.
    """
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    # A colour name is resolved per mode by Pillow, e.g. (0, 0, 0) for RGB
    # and 0 for L. A literal RGB tuple is rejected for single-band modes,
    # which made grayscale inputs crash.
    new_image = PIL.Image.new(image.mode, (side, side), 'black')
    # One of the two offsets is always 0, because the longer axis fills the
    # canvas exactly.
    new_image.paste(image, ((side - w) // 2, (side - h) // 2))
    return new_image
class Trainer:
    """Drive Custom Diffusion training runs for the Gradio UI.

    Everything lives under ``results/``:
      * the preprocessed concept images in ``training_data/``;
      * the directory passed to the script as ``--class_data_dir``;
      * the ``train.sh`` record of the launched command;
      * the ``*.bin`` files returned to the caller.

    Only one run is allowed at a time, tracked by ``is_running``.
    """

    def __init__(self):
        # True while a run is preparing data or training. run() checks it
        # and check_if_running() reports it.
        self.is_running = False
        self.is_running_message = 'Another training is in progress.'
        self.output_dir = pathlib.Path('results')
        self.instance_data_dir = self.output_dir / 'training_data'
        self.class_data_dir = self.output_dir / 'regularization_data'

    def check_if_running(self) -> dict:
        """Return a Gradio update whose value says whether a run is active."""
        if self.is_running:
            return gr.update(value=self.is_running_message)
        else:
            return gr.update(value='No training is running.')

    def cleanup_dirs(self) -> None:
        """Delete ``output_dir`` and everything in it; errors are ignored."""
        shutil.rmtree(self.output_dir, ignore_errors=True)

    def prepare_dataset(self, concept_images: list, resolution: int) -> None:
        """Save the uploaded concept images as square RGB JPEGs.

        Each image is converted to RGB, padded to a square with black
        borders, resized to ``resolution`` x ``resolution`` and written to
        ``instance_data_dir/NNN.jpg`` (``000.jpg``, ``001.jpg``, ...).

        Args:
            concept_images: Uploaded files. Each item is either a path
                (str or os.PathLike) or an object whose ``.name`` attribute
                holds the path.
            resolution: Side length, in pixels, of every saved image.

        Raises:
            FileExistsError: ``instance_data_dir`` already exists.
        """
        self.instance_data_dir.mkdir(parents=True)
        for i, item in enumerate(concept_images):
            path = item if isinstance(item, (str, os.PathLike)) else item.name
            # Close the uploaded file as soon as its pixels have been read.
            with PIL.Image.open(path) as source:
                # Converting first means padding and resizing always operate
                # on RGB data, whatever the upload's mode (grayscale,
                # palette, RGBA, ...).
                image = source.convert('RGB')
            image = pad_image(image)
            image = image.resize((resolution, resolution))
            out_path = self.instance_data_dir / f'{i:03d}.jpg'
            image.save(out_path, format='JPEG', quality=100)

    def _build_command(
        self,
        base_model: str,
        resolution: int,
        concept_prompt: str,
        class_prompt: str,
        n_steps: int,
        learning_rate: float,
        train_text_encoder: bool,
        modifier_token: bool,
        gradient_accumulation: int,
        batch_size: int,
        use_8bit_adam: bool,
    ) -> list[str]:
        """Return the ``accelerate launch`` argv for one training run.

        Every option is its own list item. Prompts containing quotes,
        backslashes or repeated spaces therefore reach the script verbatim;
        nothing is re-parsed by a shell-style tokenizer.
        """
        command = [
            'accelerate', 'launch', 'custom-diffusion/src/diffuser_training.py',
            f'--pretrained_model_name_or_path={base_model}',
            f'--instance_data_dir={self.instance_data_dir}',
            f'--output_dir={self.output_dir}',
            f'--instance_prompt={concept_prompt}',
            f'--class_data_dir={self.class_data_dir}',
            '--with_prior_preservation',
            '--real_prior',
            '--prior_loss_weight=1.0',
            f'--class_prompt={class_prompt}',
            f'--resolution={resolution}',
            f'--train_batch_size={batch_size}',
            f'--gradient_accumulation_steps={gradient_accumulation}',
            f'--learning_rate={learning_rate}',
            '--lr_scheduler=constant',
            '--lr_warmup_steps=0',
            f'--max_train_steps={n_steps}',
            '--num_class_images=200',
            '--scale_lr',
        ]
        if modifier_token:
            command += ['--modifier_token', '<new1>']
        if use_8bit_adam:
            command.append('--use_8bit_adam')
        if train_text_encoder:
            command.append('--train_text_encoder')
        return command

    def run(
        self,
        base_model: str,
        resolution_s: str,
        concept_images: list | None,
        concept_prompt: str,
        class_prompt: str,
        n_steps: int,
        learning_rate: float,
        train_text_encoder: bool,
        modifier_token: bool,
        gradient_accumulation: int,
        batch_size: int,
        use_8bit_adam: bool,
    ) -> tuple[dict, list[pathlib.Path]]:
        """Prepare the dataset, run the training script and collect weights.

        ``resolution_s`` is parsed with ``int()``. The remaining arguments
        are forwarded to the training script as command-line options.

        Returns:
            A Gradio update with a status message, plus the sorted ``*.bin``
            files found in ``output_dir`` after the run. If another run is
            active, returns that run's message and an empty list instead.

        Raises:
            gr.Error: CUDA is unavailable, no images were uploaded, or the
                concept prompt is empty.
        """
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')
        if self.is_running:
            return gr.update(value=self.is_running_message), []
        # Reject an empty upload list as well as None.
        if not concept_images:
            raise gr.Error('You need to upload images.')
        if not concept_prompt:
            raise gr.Error('The concept prompt is missing.')

        resolution = int(resolution_s)

        self.is_running = True
        try:
            self.cleanup_dirs()
            self.prepare_dataset(concept_images, resolution)
            command = self._build_command(
                base_model=base_model,
                resolution=resolution,
                concept_prompt=concept_prompt,
                class_prompt=class_prompt,
                n_steps=n_steps,
                learning_rate=learning_rate,
                train_text_encoder=train_text_encoder,
                modifier_token=modifier_token,
                gradient_accumulation=gradient_accumulation,
                batch_size=batch_size,
                use_8bit_adam=use_8bit_adam,
            )
            # Record the exact invocation, quoted so it can be re-run in a shell.
            with open(self.output_dir / 'train.sh', 'w', encoding='utf-8') as f:
                f.write(shlex.join(command))
            res = subprocess.run(command)
        finally:
            # Always release the flag. Otherwise a failed launch (e.g.
            # `accelerate` not installed) would block every later run.
            self.is_running = False

        if res.returncode == 0:
            result_message = 'Training Completed!'
        else:
            result_message = 'Training Failed!'
        weight_paths = sorted(self.output_dir.glob('*.bin'))
        return gr.update(value=result_message), weight_paths
|