muneebable committed
Commit
4135cc9
verified · 1 Parent(s): c249e39

Create app.py

Files changed (1)
app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
import gradio as gr
import torch
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO

# Load the VGG19 feature extractor and freeze its weights; only the
# target image is optimized during style transfer.
vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
for param in vgg.parameters():
    param.requires_grad_(False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)

# Helper functions (load_image, im_convert, get_features, gram_matrix)
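# The commit elides these helper definitions; the versions below are a
# minimal sketch reconstructed from how they are called in style_transfer
# (standard PyTorch neural-style-transfer utilities), not necessarily the
# author's originals.
def load_image(image, max_size=400, shape=None):
    """Turn a PIL image (or a path/URL string) into a normalized (1, 3, H, W) tensor."""
    if isinstance(image, str):
        if image.startswith("http"):
            image = Image.open(BytesIO(requests.get(image).content))
        else:
            image = Image.open(image)
    image = image.convert("RGB")
    # Cap the working resolution unless an explicit (h, w) shape is requested.
    size = shape if shape is not None else min(max(image.size), max_size)
    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        # ImageNet statistics expected by the pretrained VGG19
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    return in_transform(image)[:3, :, :].unsqueeze(0)

def im_convert(tensor):
    """Undo the normalization and return a displayable H x W x 3 numpy image in [0, 1]."""
    image = tensor.to("cpu").clone().detach().numpy().squeeze()
    image = image.transpose(1, 2, 0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    return image.clip(0, 1)

def get_features(image, model, layers=None):
    """Collect activations at the named layers (indices follow torchvision's vgg19.features)."""
    if layers is None:
        layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1',
                  '19': 'conv4_1', '21': 'conv4_2', '28': 'conv5_1'}
    features = {}
    x = image
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features

def gram_matrix(tensor):
    """Gram matrix of a (1, d, h, w) feature map: channel-by-channel correlations."""
    _, d, h, w = tensor.size()
    tensor = tensor.view(d, h * w)
    return torch.mm(tensor, tensor.t())
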
def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
    content = load_image(content_image).to(device)
    # Match the style image's spatial size to the content image's.
    style = load_image(style_image, shape=content.shape[-2:]).to(device)

    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

    # Start from the content image and optimize its pixels directly.
    target = content.clone().requires_grad_(True).to(device)

    style_weights = {
        'conv1_1': conv1_1,
        'conv2_1': conv2_1,
        'conv3_1': conv3_1,
        'conv4_1': conv4_1,
        'conv5_1': conv5_1
    }

    content_weight = alpha
    style_weight = beta * 1e6

    optimizer = optim.Adam([target], lr=0.003)

    # Gradio sliders pass floats, so coerce the step count to an int.
    for ii in range(1, int(steps) + 1):
        target_features = get_features(target, vgg)
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

        style_loss = 0
        for layer in style_weights:
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            _, d, h, w = target_feature.shape
            style_gram = style_grams[layer]
            layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
            style_loss += layer_style_loss / (d * h * w)

        total_loss = content_weight * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    return im_convert(target)

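# In effect, the loop above minimizes the Gatys et al. (2015) objective
#   L_total = alpha * L_content + beta * L_style
# with the content loss taken at conv4_2 and the style loss comparing Gram
# matrices at the conv*_1 layers, each weighted by its slider value.
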
# Example images
examples = [
    ["path/to/content1.jpg", "path/to/style1.jpg"],
    ["path/to/content2.jpg", "path/to/style2.jpg"],
    ["path/to/content3.jpg", "path/to/style3.jpg"],
]

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Neural Style Transfer")
    with gr.Row():
        with gr.Column():
            # type="pil" so style_transfer receives PIL images, as load_image expects
            content_input = gr.Image(label="Content Image", type="pil")
            style_input = gr.Image(label="Style Image", type="pil")
        with gr.Column():
            output_image = gr.Image(label="Output Image")

    with gr.Row():
        alpha_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Content Weight (α)")
        beta_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Style Weight (β)")

    with gr.Row():
        conv1_1_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Conv1_1 Weight")
        conv2_1_slider = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.1, label="Conv2_1 Weight")
        conv3_1_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Conv3_1 Weight")
        conv4_1_slider = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Conv4_1 Weight")
        conv5_1_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Conv5_1 Weight")

    steps_slider = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, label="Number of Steps")

    run_button = gr.Button("Run Style Transfer")

    run_button.click(
        style_transfer,
        inputs=[
            content_input,
            style_input,
            alpha_slider,
            beta_slider,
            conv1_1_slider,
            conv2_1_slider,
            conv3_1_slider,
            conv4_1_slider,
            conv5_1_slider,
            steps_slider
        ],
        outputs=output_image
    )

    gr.Examples(
        examples,
        inputs=[content_input, style_input]
    )

demo.launch()
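
To run this Space, the environment must provide the imported packages; no requirements.txt appears in this commit, so the following is an assumed minimal set:

gradio
torch
torchvision
Pillow
numpy
requests

The first launch also downloads VGG19's pretrained ImageNet weights, and CPU-only hardware will be slow at the default 1000 optimization steps.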