Vivien commited on
Commit
fd7f424
·
1 Parent(s): 3370440

Switch to transformers

Browse files
Files changed (4) hide show
  1. .gitignore +0 -1
  2. app.py +31 -57
  3. packages.txt +0 -1
  4. requirements.txt +2 -3
.gitignore DELETED
@@ -1 +0,0 @@
1
- result.jpg
 
 
app.py CHANGED
@@ -1,18 +1,9 @@
1
  import numpy as np
2
- import PIL
 
3
  import torch
4
  import streamlit as st
5
 
6
- device = torch.device("cpu")
7
-
8
- DEBUG = False
9
- if DEBUG:
10
- cache_kwargs = {"max_entries": 30}
11
- model_name = "MiDaS_small"
12
- else:
13
- cache_kwargs = {"show_spinner": False, "max_entries": 30}
14
- model_name = "DPT_Large"
15
-
16
  FONTS = [
17
  "Font: Serif - EBGaramond",
18
  "Font: Serif - Cinzel",
@@ -35,41 +26,36 @@ def hex_to_rgb(hex):
35
  return tuple(rgb)
36
 
37
 
38
- @st.experimental_singleton
39
- def load(model_type):
40
- midas = torch.hub.load("intel-isl/MiDaS", model_type)
41
- midas.to(device)
42
- _ = midas.eval()
43
- midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
44
- if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
45
- transform = midas_transforms.dpt_transform
46
- else:
47
- transform = midas_transforms.small_transform
48
- return midas, transform
49
 
50
 
51
- midas, transform = load(model_name)
52
 
53
 
54
- @st.experimental_memo(**cache_kwargs)
55
- def compute_depth(img):
56
  with torch.no_grad():
57
- prediction = midas(transform(img).to(device))
58
- prediction = torch.nn.functional.interpolate(
59
- prediction.unsqueeze(1),
60
- size=img.shape[:2],
61
- mode="bicubic",
62
- align_corners=False,
63
- ).squeeze()
64
- return prediction.cpu().numpy()
 
65
 
66
 
67
  def get_mask1(
68
  shape, x, y, caption, font=None, font_size=0.08, color=(0, 0, 0), alpha=0.8
69
  ):
70
- img_text = PIL.Image.new("RGBA", (shape[1], shape[0]), (0, 0, 0, 0))
71
- draw = PIL.ImageDraw.Draw(img_text)
72
- font = PIL.ImageFont.truetype(font, int(font_size * shape[1]))
73
  draw.text(
74
  (x * shape[1], (1 - y) * shape[0]),
75
  caption,
@@ -78,8 +64,7 @@ def get_mask1(
78
  )
79
  text = np.array(img_text)
80
  mask1 = np.dot(np.expand_dims(text[:, :, -1] / 255, -1), np.ones((1, 3)))
81
- text = text[:, :, :-1]
82
- return text, mask1
83
 
84
 
85
  def get_mask2(depth_map, depth):
@@ -101,8 +86,6 @@ def add_caption(
101
  font="",
102
  alpha=1,
103
  ):
104
- if depth_map is None:
105
- depth_map = compute_depth(img)
106
  text, mask1 = get_mask1(
107
  img.shape,
108
  x,
@@ -119,13 +102,13 @@ def add_caption(
119
  return ((1 - mask) * img + mask * text).astype(np.uint8)
120
 
121
 
122
- @st.experimental_memo(**cache_kwargs)
123
  def load_img(uploaded_file):
124
  if uploaded_file is None:
125
- img = PIL.Image.open("pulp.jpg")
126
  default = True
127
  else:
128
- img = PIL.Image.open(uploaded_file)
129
  if img.size[0] > 800 or img.size[1] > 800:
130
  if img.size[0] < img.size[1]:
131
  new_size = (int(800 * img.size[0] / img.size[1]), 800)
@@ -133,7 +116,7 @@ def load_img(uploaded_file):
133
  new_size = (800, int(800 * img.size[1] / img.size[0]))
134
  img = img.resize(new_size)
135
  default = False
136
- return np.array(img), default
137
 
138
 
139
  def main():
@@ -163,13 +146,11 @@ def main():
163
  )
164
 
165
  uploaded_file = st.file_uploader("", type=["jpg", "jpeg"])
166
-
167
- img, default = load_img(uploaded_file)
168
-
169
- del uploaded_file
170
 
171
  if default:
172
- x0, y0, alpha0, font_size0, depth0, font0 = 0.02, 0.68, 0.99, 0.07, 0.23, 4
173
  text0 = "Pulp Fiction"
174
  else:
175
  x0, y0, alpha0, font_size0, depth0, font0 = 0.1, 0.9, 0.8, 0.08, 0.5, 0
@@ -243,6 +224,7 @@ def main():
243
  x=x,
244
  y=y,
245
  depth=depth,
 
246
  font=font,
247
  font_size=font_size,
248
  alpha=alpha,
@@ -251,14 +233,6 @@ def main():
251
 
252
  st.image(captioned)
253
 
254
- # PIL.Image.fromarray(captioned).save("result.jpg")
255
- # with open("result.jpg", "rb") as file:
256
- # btn = st.download_button(
257
- # label="Download image", data=file, file_name="result.jpg", mime="image/jpeg"
258
- # )
259
-
260
- del captioned, img
261
-
262
 
263
  if __name__ == "__main__":
264
  main()
 
1
  import numpy as np
2
+ from PIL import ImageDraw, Image, ImageFont
3
+ from transformers import DPTFeatureExtractor, DPTForDepthEstimation
4
  import torch
5
  import streamlit as st
6
 
 
 
 
 
 
 
 
 
 
 
7
  FONTS = [
8
  "Font: Serif - EBGaramond",
9
  "Font: Serif - Cinzel",
 
26
  return tuple(rgb)
27
 
28
 
29
+ @st.cache(allow_output_mutation=True)
30
+ def load():
31
+ feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
32
+ model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
33
+ return model, feature_extractor
 
 
 
 
 
 
34
 
35
 
36
+ model, feature_extractor = load()
37
 
38
 
39
+ def compute_depth(image):
40
+ inputs = feature_extractor(images=image, return_tensors="pt")
41
  with torch.no_grad():
42
+ outputs = model(**inputs)
43
+ predicted_depth = outputs.predicted_depth
44
+ prediction = torch.nn.functional.interpolate(
45
+ predicted_depth.unsqueeze(1),
46
+ size=image.size[::-1],
47
+ mode="bicubic",
48
+ align_corners=False,
49
+ )
50
+ return prediction.cpu().numpy()[0, 0, :, :]
51
 
52
 
53
  def get_mask1(
54
  shape, x, y, caption, font=None, font_size=0.08, color=(0, 0, 0), alpha=0.8
55
  ):
56
+ img_text = Image.new("RGBA", (shape[1], shape[0]), (0, 0, 0, 0))
57
+ draw = ImageDraw.Draw(img_text)
58
+ font = ImageFont.truetype(font, int(font_size * shape[1]))
59
  draw.text(
60
  (x * shape[1], (1 - y) * shape[0]),
61
  caption,
 
64
  )
65
  text = np.array(img_text)
66
  mask1 = np.dot(np.expand_dims(text[:, :, -1] / 255, -1), np.ones((1, 3)))
67
+ return text[:, :, :-1], mask1
 
68
 
69
 
70
  def get_mask2(depth_map, depth):
 
86
  font="",
87
  alpha=1,
88
  ):
 
 
89
  text, mask1 = get_mask1(
90
  img.shape,
91
  x,
 
102
  return ((1 - mask) * img + mask * text).astype(np.uint8)
103
 
104
 
105
+ @st.cache(max_entries=30, show_spinner=False)
106
  def load_img(uploaded_file):
107
  if uploaded_file is None:
108
+ img = Image.open("pulp.jpg")
109
  default = True
110
  else:
111
+ img = Image.open(uploaded_file)
112
  if img.size[0] > 800 or img.size[1] > 800:
113
  if img.size[0] < img.size[1]:
114
  new_size = (int(800 * img.size[0] / img.size[1]), 800)
 
116
  new_size = (800, int(800 * img.size[1] / img.size[0]))
117
  img = img.resize(new_size)
118
  default = False
119
+ return np.array(img), compute_depth(img), default
120
 
121
 
122
  def main():
 
146
  )
147
 
148
  uploaded_file = st.file_uploader("", type=["jpg", "jpeg"])
149
+ with st.spinner("Analyzing the image - Please wait a few seconds"):
150
+ img, depth_map, default = load_img(uploaded_file)
 
 
151
 
152
  if default:
153
+ x0, y0, alpha0, font_size0, depth0, font0 = 0.02, 0.68, 0.99, 0.07, 0.12, 4
154
  text0 = "Pulp Fiction"
155
  else:
156
  x0, y0, alpha0, font_size0, depth0, font0 = 0.1, 0.9, 0.8, 0.08, 0.5, 0
 
224
  x=x,
225
  y=y,
226
  depth=depth,
227
+ depth_map=depth_map,
228
  font=font,
229
  font_size=font_size,
230
  alpha=alpha,
 
233
 
234
  st.image(captioned)
235
 
 
 
 
 
 
 
 
 
236
 
237
  if __name__ == "__main__":
238
  main()
packages.txt DELETED
@@ -1 +0,0 @@
1
- libgl1
 
 
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  numpy
2
  torch
3
- timm
4
- pillow
5
- opencv-python
 
1
  numpy
2
  torch
3
+ transformers
4
+ pillow