import gradio as gr
import torch
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO

# Load the pretrained VGG19 and keep only its convolutional feature
# extractor (the classifier head is not needed for style transfer).
# Note: the `pretrained=True` argument is deprecated in current torchvision.
vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features

# Freeze the network: we optimize the target image's pixels, not the weights.
for param in vgg.parameters():
    param.requires_grad_(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)

# Helper functions: load_image, im_convert, get_features, gram_matrix.
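# The four helpers below are a minimal sketch of the standard Gatys-style
# transfer recipe (as in the common PyTorch tutorials); if you have your own
# versions from earlier, substitute them. Assumed here: ImageNet
# normalization statistics and torchvision's VGG19 layer indices.

def load_image(img, max_size=400, shape=None):
    """Accept a numpy array (what gr.Image passes by default), a PIL image,
    or a path/URL string, and return a normalized 4-D float tensor."""
    if isinstance(img, str):
        if img.startswith("http"):
            img = Image.open(BytesIO(requests.get(img).content)).convert("RGB")
        else:
            img = Image.open(img).convert("RGB")
    elif isinstance(img, np.ndarray):
        img = Image.fromarray(img).convert("RGB")

    # Cap the size at max_size unless an explicit shape is requested.
    size = shape if shape is not None else min(max(img.size), max_size)

    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225)),
    ])
    # Drop any alpha channel and add a batch dimension.
    return in_transform(img)[:3, :, :].unsqueeze(0)


def im_convert(tensor):
    """Undo the normalization and return a displayable uint8 numpy image."""
    image = tensor.to("cpu").clone().detach().numpy().squeeze()
    image = image.transpose(1, 2, 0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    return (image.clip(0, 1) * 255).astype(np.uint8)


def get_features(image, model, layers=None):
    """Collect activations at the VGG19 layers used for style (conv*_1)
    and content (conv4_2)."""
    if layers is None:
        layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1',
                  '19': 'conv4_1', '21': 'conv4_2', '28': 'conv5_1'}
    features = {}
    x = image
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features


def gram_matrix(tensor):
    """Gram matrix of a (1, d, h, w) feature map: channel-by-channel
    correlations, flattened over spatial positions."""
    _, d, h, w = tensor.size()
    tensor = tensor.view(d, h * w)
    return torch.mm(tensor, tensor.t())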

def style_transfer(content_image, style_image, alpha, beta, conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
    """Run Gatys-style neural style transfer and return the stylized image."""
    content = load_image(content_image).to(device)
    style = load_image(style_image, shape=content.shape[-2:]).to(device)
    
    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
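    # Precompute the style targets once: one Gram matrix per style layer.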
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
    
    # The generated image starts as a copy of the content image.
    target = content.clone().to(device).requires_grad_(True)
    
    # Per-layer style emphasis: earlier layers capture fine texture,
    # deeper layers capture larger-scale structure.
    style_weights = {
        'conv1_1': conv1_1,
        'conv2_1': conv2_1,
        'conv3_1': conv3_1,
        'conv4_1': conv4_1,
        'conv5_1': conv5_1
    }
    
    # Trade-off between preserving content (alpha) and matching style (beta);
    # the 1e6 factor offsets the small magnitude of the style loss.
    content_weight = alpha
    style_weight = beta * 1e6
    
    optimizer = optim.Adam([target], lr=0.003)
    
    # Gradio sliders pass floats, so cast steps before handing it to range().
    for ii in range(1, int(steps) + 1):
        target_features = get_features(target, vgg)
        # Content loss: feature distance at conv4_2, a deep, content-heavy layer.
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)
        
        # Style loss: weighted Gram-matrix distance, normalized by layer size.
        style_loss = 0
        for layer in style_weights:
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            _, d, h, w = target_feature.shape
            style_gram = style_grams[layer]
            layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram)**2)
            style_loss += layer_style_loss / (d * h * w)
        
        total_loss = content_weight * content_loss + style_weight * style_loss
        
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    
    return im_convert(target)

# Example image pairs (replace these placeholder paths with real image files
# before launching).
examples = [
    ["path/to/content1.jpg", "path/to/style1.jpg"],
    ["path/to/content2.jpg", "path/to/style2.jpg"],
    ["path/to/content3.jpg", "path/to/style3.jpg"],
]

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Neural Style Transfer")
    with gr.Row():
        with gr.Column():
            content_input = gr.Image(label="Content Image")
            style_input = gr.Image(label="Style Image")
        with gr.Column():
            output_image = gr.Image(label="Output Image")
    
    with gr.Row():
        alpha_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Content Weight (α)")
        beta_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Style Weight (β)")
    
    with gr.Row():
        conv1_1_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Conv1_1 Weight")
        conv2_1_slider = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.1, label="Conv2_1 Weight")
        conv3_1_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Conv3_1 Weight")
        conv4_1_slider = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Conv4_1 Weight")
        conv5_1_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Conv5_1 Weight")
    
    steps_slider = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, label="Number of Steps")
    
    run_button = gr.Button("Run Style Transfer")
    
    run_button.click(
        style_transfer,
        inputs=[
            content_input,
            style_input,
            alpha_slider,
            beta_slider,
            conv1_1_slider,
            conv2_1_slider,
            conv3_1_slider,
            conv4_1_slider,
            conv5_1_slider,
            steps_slider
        ],
        outputs=output_image
    )
    
    gr.Examples(
        examples,
        inputs=[content_input, style_input]
    )

demo.launch()
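
# Note: on CPU, 1000+ optimization steps can take several minutes per image;
# lower the steps slider for quick experiments.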