import gradio as gr
import torch
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import requests
from io import BytesIO
# Load the pretrained VGG19 network and keep only its convolutional
# feature extractor; the classifier head is not needed for style transfer.
vgg = models.vgg19(pretrained=True).features

# Freeze all parameters: we optimize the target image's pixels, not the network.
for param in vgg.parameters():
    param.requires_grad_(False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg.to(device)
# Helper functions used below: load_image, im_convert, get_features, gram_matrix.
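# These are a minimal sketch following the widely used Gatys-style VGG19
# recipe (ImageNet normalization, conv4_2 as the content layer). The
# layer-index mapping assumes torchvision's VGG19 `features` layout.

def load_image(img, max_size=400, shape=None):
    """Convert a path, URL, or numpy array (as Gradio passes) into a
    normalized 4-D tensor, capped at max_size on the longest side."""
    if isinstance(img, np.ndarray):
        image = Image.fromarray(img).convert('RGB')
    elif str(img).startswith('http'):
        response = requests.get(img)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(img).convert('RGB')

    size = shape if shape is not None else min(max(image.size), max_size)
    in_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])
    # Drop any alpha channel and add a batch dimension.
    return in_transform(image)[:3, :, :].unsqueeze(0)

def im_convert(tensor):
    """Undo the normalization and return an H x W x C float image in [0, 1]."""
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1, 2, 0)
    image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    return image.clip(0, 1)

def get_features(image, model, layers=None):
    """Collect activations at the named VGG19 layers."""
    if layers is None:
        layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1',
                  '19': 'conv4_1', '21': 'conv4_2',  # conv4_2 = content layer
                  '28': 'conv5_1'}
    features = {}
    x = image
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features

def gram_matrix(tensor):
    """Gram matrix of a feature map: correlations between channel activations."""
    _, d, h, w = tensor.size()
    tensor = tensor.view(d, h * w)
    return torch.mm(tensor, tensor.t())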
def style_transfer(content_image, style_image, alpha, beta,
                   conv1_1, conv2_1, conv3_1, conv4_1, conv5_1, steps):
    # Prepare the inputs; the style image is resized to match the content image.
    content = load_image(content_image).to(device)
    style = load_image(style_image, shape=content.shape[-2:]).to(device)

    content_features = get_features(content, vgg)
    style_features = get_features(style, vgg)
    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

    # Start the target from the content image and optimize its pixels directly.
    target = content.clone().requires_grad_(True).to(device)

    style_weights = {
        'conv1_1': conv1_1,
        'conv2_1': conv2_1,
        'conv3_1': conv3_1,
        'conv4_1': conv4_1,
        'conv5_1': conv5_1,
    }
    content_weight = alpha
    style_weight = beta * 1e6

    optimizer = optim.Adam([target], lr=0.003)

    for ii in range(1, int(steps) + 1):  # sliders deliver floats, so cast steps
        target_features = get_features(target, vgg)

        # Content loss: MSE between target and content activations at conv4_2.
        content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

        # Style loss: per-layer weighted MSE between Gram matrices,
        # normalized by the feature map size.
        style_loss = 0
        for layer in style_weights:
            target_feature = target_features[layer]
            target_gram = gram_matrix(target_feature)
            _, d, h, w = target_feature.shape
            style_gram = style_grams[layer]
            layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram) ** 2)
            style_loss += layer_style_loss / (d * h * w)

        total_loss = content_weight * content_loss + style_weight * style_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    # Return a uint8 array so Gradio can display it directly.
    return (im_convert(target) * 255).astype(np.uint8)
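# Quick smoke test outside the UI (hypothetical local files; adjust before use):
#   result = style_transfer("content.jpg", "style.jpg", alpha=1, beta=0.1,
#                           conv1_1=1, conv2_1=0.8, conv3_1=0.5,
#                           conv4_1=0.3, conv5_1=0.1, steps=200)
#   Image.fromarray(result).save("stylized.jpg")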
# Example image pairs (placeholder paths; replace with real files before use)
examples = [
    ["path/to/content1.jpg", "path/to/style1.jpg"],
    ["path/to/content2.jpg", "path/to/style2.jpg"],
    ["path/to/content3.jpg", "path/to/style3.jpg"],
]
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Neural Style Transfer")
    with gr.Row():
        with gr.Column():
            content_input = gr.Image(label="Content Image")
            style_input = gr.Image(label="Style Image")
        with gr.Column():
            output_image = gr.Image(label="Output Image")
    with gr.Row():
        alpha_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Content Weight (α)")
        beta_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Style Weight (β)")
    with gr.Row():
        conv1_1_slider = gr.Slider(minimum=0, maximum=1, value=1, step=0.1, label="Conv1_1 Weight")
        conv2_1_slider = gr.Slider(minimum=0, maximum=1, value=0.8, step=0.1, label="Conv2_1 Weight")
        conv3_1_slider = gr.Slider(minimum=0, maximum=1, value=0.5, step=0.1, label="Conv3_1 Weight")
        conv4_1_slider = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.1, label="Conv4_1 Weight")
        conv5_1_slider = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="Conv5_1 Weight")
    steps_slider = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, label="Number of Steps")
    run_button = gr.Button("Run Style Transfer")
    run_button.click(
        style_transfer,
        inputs=[
            content_input,
            style_input,
            alpha_slider,
            beta_slider,
            conv1_1_slider,
            conv2_1_slider,
            conv3_1_slider,
            conv4_1_slider,
            conv5_1_slider,
            steps_slider,
        ],
        outputs=output_image,
    )
    gr.Examples(
        examples,
        inputs=[content_input, style_input],
    )

demo.launch()