hysts commited on
Commit
b2c2519
·
1 Parent(s): ca01e24

Use extracted lightweight models

Browse files
Files changed (3) hide show
  1. app.py +3 -2
  2. model.py +60 -20
  3. requirements.txt +1 -0
app.py CHANGED
@@ -63,9 +63,10 @@ with gr.Blocks(css='style.css') as demo:
63
  create_demo_scribble(model.process_scribble, max_images=MAX_IMAGES)
64
  with gr.TabItem('Scribble Interactive'):
65
  create_demo_scribble_interactive(
66
- model.process_scribble_interactive, max_images=MAX_IMAGES)
67
  with gr.TabItem('Fake Scribble'):
68
- create_demo_fake_scribble(model.process_fake_scribble, max_images=MAX_IMAGES)
 
69
  with gr.TabItem('Pose'):
70
  create_demo_pose(model.process_pose, max_images=MAX_IMAGES)
71
  with gr.TabItem('Segmentation'):
 
63
  create_demo_scribble(model.process_scribble, max_images=MAX_IMAGES)
64
  with gr.TabItem('Scribble Interactive'):
65
  create_demo_scribble_interactive(
66
+ model.process_scribble_interactive, max_images=MAX_IMAGES)
67
  with gr.TabItem('Fake Scribble'):
68
+ create_demo_fake_scribble(model.process_fake_scribble,
69
+ max_images=MAX_IMAGES)
70
  with gr.TabItem('Pose'):
71
  create_demo_pose(model.process_pose, max_images=MAX_IMAGES)
72
  with gr.TabItem('Segmentation'):
model.py CHANGED
@@ -28,22 +28,36 @@ from cldm.model import create_model, load_state_dict
28
  from ldm.models.diffusion.ddim import DDIMSampler
29
  from share import *
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- class Model:
33
- WEIGHT_NAMES = {
34
- 'canny': 'control_sd15_canny.pth',
35
- 'hough': 'control_sd15_mlsd.pth',
36
- 'hed': 'control_sd15_hed.pth',
37
- 'scribble': 'control_sd15_scribble.pth',
38
- 'pose': 'control_sd15_openpose.pth',
39
- 'seg': 'control_sd15_seg.pth',
40
- 'depth': 'control_sd15_depth.pth',
41
- 'normal': 'control_sd15_normal.pth',
42
- }
43
 
 
44
  def __init__(self,
45
  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
46
- model_dir: str = 'models'):
 
47
  self.device = torch.device(
48
  'cuda:0' if torch.cuda.is_available() else 'cpu')
49
  self.model = create_model(model_config_path).to(self.device)
@@ -51,31 +65,57 @@ class Model:
51
  self.task_name = ''
52
 
53
  self.model_dir = pathlib.Path(model_dir)
 
 
 
 
 
 
 
 
 
 
 
 
54
  self.download_models()
55
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def load_weight(self, task_name: str) -> None:
57
  if task_name == self.task_name:
58
  return
59
  weight_path = self.get_weight_path(task_name)
60
- self.model.load_state_dict(
61
- load_state_dict(weight_path, location=self.device))
 
 
 
 
62
  self.task_name = task_name
63
 
64
  def get_weight_path(self, task_name: str) -> str:
65
  if 'scribble' in task_name:
66
  task_name = 'scribble'
67
- return f'{self.model_dir}/{self.WEIGHT_NAMES[task_name]}'
68
 
69
- def download_models(self):
70
  self.model_dir.mkdir(exist_ok=True, parents=True)
71
- for name in self.WEIGHT_NAMES.values():
72
  out_path = self.model_dir / name
73
  if out_path.exists():
74
  continue
75
  subprocess.run(
76
- shlex.split(
77
- f'wget https://huggingface.co/ckpt/ControlNet/resolve/main/{name} -O {out_path}'
78
- ))
79
 
80
  @torch.inference_mode()
81
  def process_canny(self, input_image, prompt, a_prompt, n_prompt,
 
28
  from ldm.models.diffusion.ddim import DDIMSampler
29
  from share import *
30
 
31
+ ORIGINAL_MODEL_NAMES = {
32
+ 'canny': 'control_sd15_canny.pth',
33
+ 'hough': 'control_sd15_mlsd.pth',
34
+ 'hed': 'control_sd15_hed.pth',
35
+ 'scribble': 'control_sd15_scribble.pth',
36
+ 'pose': 'control_sd15_openpose.pth',
37
+ 'seg': 'control_sd15_seg.pth',
38
+ 'depth': 'control_sd15_depth.pth',
39
+ 'normal': 'control_sd15_normal.pth',
40
+ }
41
+ ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/ckpt/ControlNet/resolve/main/'
42
+
43
+ LIGHTWEIGHT_MODEL_NAMES = {
44
+ 'canny': 'control_canny-fp16.safetensors',
45
+ 'hough': 'control_mlsd-fp16.safetensors',
46
+ 'hed': 'control_hed-fp16.safetensors',
47
+ 'scribble': 'control_scribble-fp16.safetensors',
48
+ 'pose': 'control_openpose-fp16.safetensors',
49
+ 'seg': 'control_seg-fp16.safetensors',
50
+ 'depth': 'control_depth-fp16.safetensors',
51
+ 'normal': 'control_normal-fp16.safetensors',
52
+ }
53
+ LIGHTWEIGHT_WEIGHT_ROOT = 'https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/'
54
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ class Model:
57
  def __init__(self,
58
  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
59
+ model_dir: str = 'models',
60
+ use_lightweight: bool = True):
61
  self.device = torch.device(
62
  'cuda:0' if torch.cuda.is_available() else 'cpu')
63
  self.model = create_model(model_config_path).to(self.device)
 
65
  self.task_name = ''
66
 
67
  self.model_dir = pathlib.Path(model_dir)
68
+
69
+ self.use_lightweight = use_lightweight
70
+ if use_lightweight:
71
+ self.model_names = LIGHTWEIGHT_MODEL_NAMES
72
+ self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
73
+ base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
74
+ self.download_base_model(base_model_url)
75
+ base_model_path = self.model_dir / base_model_url.split('/')[-1]
76
+ self.load_base_model(base_model_path)
77
+ else:
78
+ self.model_names = ORIGINAL_MODEL_NAMES
79
+ self.weight_root = ORIGINAL_WEIGHT_ROOT
80
  self.download_models()
81
 
82
+ def download_base_model(self, base_model_url: str) -> None:
83
+ model_name = base_model_url.split('/')[-1]
84
+ out_path = self.model_dir / model_name
85
+ if out_path.exists():
86
+ return
87
+ subprocess.run(shlex.split(f'wget {base_model_url} -O {out_path}'))
88
+
89
+ def load_base_model(self, model_path: pathlib.Path) -> None:
90
+ self.model.load_state_dict(load_state_dict(model_path,
91
+ location=self.device.type),
92
+ strict=False)
93
+
94
  def load_weight(self, task_name: str) -> None:
95
  if task_name == self.task_name:
96
  return
97
  weight_path = self.get_weight_path(task_name)
98
+ if not self.use_lightweight:
99
+ self.model.load_state_dict(
100
+ load_state_dict(weight_path, location=self.device))
101
+ else:
102
+ self.model.control_model.load_state_dict(
103
+ load_state_dict(weight_path, location=self.device.type))
104
  self.task_name = task_name
105
 
106
  def get_weight_path(self, task_name: str) -> str:
107
  if 'scribble' in task_name:
108
  task_name = 'scribble'
109
+ return f'{self.model_dir}/{self.model_names[task_name]}'
110
 
111
+ def download_models(self) -> None:
112
  self.model_dir.mkdir(exist_ok=True, parents=True)
113
+ for name in self.model_names.values():
114
  out_path = self.model_dir / name
115
  if out_path.exists():
116
  continue
117
  subprocess.run(
118
+ shlex.split(f'wget {self.weight_root}{name} -O {out_path}'))
 
 
119
 
120
  @torch.inference_mode()
121
  def process_canny(self, input_image, prompt, a_prompt, n_prompt,
requirements.txt CHANGED
@@ -11,6 +11,7 @@ opencv-contrib-python==4.7.0.68
11
  opencv-python-headless==4.7.0.68
12
  prettytable==3.6.0
13
  pytorch-lightning==1.9.0
 
14
  timm==0.6.12
15
  torch==1.13.1
16
  torchvision==0.14.1
 
11
  opencv-python-headless==4.7.0.68
12
  prettytable==3.6.0
13
  pytorch-lightning==1.9.0
14
+ safetensors==0.2.8
15
  timm==0.6.12
16
  torch==1.13.1
17
  torchvision==0.14.1