muneebable commited on
Commit
aee71f9
·
verified ·
1 Parent(s): 29df5af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -10
app.py CHANGED
@@ -101,6 +101,23 @@ def gram_matrix(tensor):
101
 
102
  return gram
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
105
  content = load_image(content_image).to(device)
106
  style = load_image(style_image, shape=content.shape[-2:]).to(device)
@@ -124,6 +141,9 @@ def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, co
124
 
125
  optimizer = optim.Adam([target], lr=0.003)
126
 
 
 
 
127
  for ii in range(1, steps+1):
128
  target_features = get_features(target, vgg)
129
  content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
@@ -142,17 +162,15 @@ def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, co
142
  optimizer.zero_grad()
143
  total_loss.backward()
144
  optimizer.step()
 
 
 
145
 
146
- return im_convert(target)
147
-
 
 
148
 
149
- # Function to resize image while maintaining aspect ratio
150
- def resize_image(image_path, max_size=400):
151
- img = Image.open(image_path).convert('RGB')
152
- ratio = max_size / max(img.size)
153
- new_size = tuple([int(x*ratio) for x in img.size])
154
- img = img.resize(new_size, Image.Resampling.LANCZOS)
155
- return np.array(img)
156
 
157
  # Example images
158
  # examples = [
@@ -208,6 +226,7 @@ with gr.Blocks() as demo:
208
  steps_slider = gr.Slider(minimum=1, maximum=2000, value=1000, step=100, label="Number of Steps")
209
 
210
  run_button = gr.Button("Run Style Transfer")
 
211
 
212
  run_button.click(
213
  style_transfer,
@@ -223,7 +242,7 @@ with gr.Blocks() as demo:
223
  conv5_1_slider,
224
  steps_slider
225
  ],
226
- outputs=output_image
227
  )
228
 
229
  gr.Examples(
 
101
 
102
  return gram
103
 
104
+ # Function to resize image while maintaining aspect ratio
105
+ def resize_image(image_path, max_size=400):
106
+ img = Image.open(image_path).convert('RGB')
107
+ ratio = max_size / max(img.size)
108
+ new_size = tuple([int(x*ratio) for x in img.size])
109
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
110
+ return np.array(img)
111
+
112
+ def create_grid(images, rows, cols):
113
+ assert len(images) == rows * cols, "Number of images doesn't match the grid size"
114
+ w, h = images[0].shape[1], images[0].shape[0]
115
+ grid = np.zeros((h*rows, w*cols, 3), dtype=np.uint8)
116
+ for i, img in enumerate(images):
117
+ r, c = divmod(i, cols)
118
+ grid[r*h:(r+1)*h, c*w:(c+1)*w] = img
119
+ return grid
120
+
121
  def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
122
  content = load_image(content_image).to(device)
123
  style = load_image(style_image, shape=content.shape[-2:]).to(device)
 
141
 
142
  optimizer = optim.Adam([target], lr=0.003)
143
 
144
+ intermediate_images = []
145
+ show_every = steps // 9 # Show 9 intermediate images
146
+
147
  for ii in range(1, steps+1):
148
  target_features = get_features(target, vgg)
149
  content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
 
162
  optimizer.zero_grad()
163
  total_loss.backward()
164
  optimizer.step()
165
+
166
+ if ii % show_every == 0 or ii == steps:
167
+ intermediate_images.append(im_convert(target))
168
 
169
+ final_image = intermediate_images[-1]
170
+ intermediate_grid = create_grid(intermediate_images, 3, 3)
171
+
172
+ return final_image, intermediate_grid
173
 
 
 
 
 
 
 
 
174
 
175
  # Example images
176
  # examples = [
 
226
  steps_slider = gr.Slider(minimum=1, maximum=2000, value=1000, step=100, label="Number of Steps")
227
 
228
  run_button = gr.Button("Run Style Transfer")
229
+ intermediate_output = gr.Image(label="Intermediate Results")
230
 
231
  run_button.click(
232
  style_transfer,
 
242
  conv5_1_slider,
243
  steps_slider
244
  ],
245
+ outputs=[output_image, intermediate_output]
246
  )
247
 
248
  gr.Examples(