# clip_aes / app.py
# Hugging Face Space by haor. Last commit: "Update app.py" (b5fd63d, verified), 4.33 kB.
import gradio as gr
import torch
import torch.nn as nn
import clip
import pandas as pd
import hashlib
import numpy as np
import cv2
import time
from PIL import Image
class MLP(nn.Module):
    """Regression head that maps an image embedding to a single score.

    The stack contains only ``nn.Linear`` and ``nn.Dropout`` layers (no
    activation functions), so in eval mode the whole network is one affine
    map from ``input_size`` features down to one output.

    Args:
        input_size: width of the incoming embedding vectors.
    """

    def __init__(self, input_size):
        super().__init__()
        # Keep this exact order: the positions inside ``nn.Sequential`` become
        # the state_dict keys (layers.0, layers.2, ..., layers.7) that a
        # pretrained checkpoint must match.
        stack = [
            nn.Linear(input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` (last dim == input_size) through the layer stack."""
        out = self.layers(x)
        return out
def binary_array_to_hex(arr):
    """Pack an array of 0/1 (or boolean) values into a hex string.

    The array is flattened in row-major order and read as one big-endian
    binary number. The result is left-padded with zeros to ``ceil(bits / 4)``
    hex digits, so leading zero bits are never lost.

    Args:
        arr: NumPy array whose elements are 0/1 or False/True.

    Returns:
        Lower-case hexadecimal string.
    """
    # ``* 1`` turns booleans into 0/1 integers before stringifying.
    digits = "".join(map(str, arr.flatten() * 1))
    hex_width = -(-len(digits) // 4)  # ceiling division
    return format(int(digits, 2), f"0>{hex_width}x")
def phash(image, hash_size=8, highfreq_factor=4):
    """Compute a DCT-based perceptual hash of a PIL image.

    Steps:
    1. Convert the image to greyscale.
    2. Resize it (LANCZOS) to a square of side ``hash_size * highfreq_factor``.
    3. Take a 2-D DCT (rows, then columns).
    4. Keep the top-left ``hash_size x hash_size`` coefficients.
    5. Set each bit to whether its coefficient exceeds the median of that
       block.

    Args:
        image: PIL image.
        hash_size: side of the kept coefficient block; the hash has
            ``hash_size ** 2`` bits.
        highfreq_factor: oversampling factor applied before the DCT.

    Returns:
        The hash as a zero-padded hex string.

    Raises:
        ValueError: if ``hash_size`` is smaller than 2.
    """
    if hash_size < 2:
        raise ValueError('Hash size must be greater than or equal to 2')
    from scipy import fftpack  # imported lazily, as in the original

    side = hash_size * highfreq_factor
    grey = image.convert('L').resize((side, side), Image.Resampling.LANCZOS)
    coeffs = fftpack.dct(fftpack.dct(np.asarray(grey), axis=0), axis=1)
    low_band = coeffs[:hash_size, :hash_size]
    return binary_array_to_hex(low_band > np.median(low_band))
def convert_numpy_types(data):
    """Recursively convert NumPy values into JSON-friendly Python built-ins.

    Dicts and lists are walked recursively (dict keys are left unchanged).
    Any NumPy scalar (``np.float64``, ``np.float32``, ``np.int32``,
    ``np.bool_``, ...) becomes the matching Python scalar via ``.item()``.
    Any ``np.ndarray`` becomes a (nested) Python list via ``.tolist()``.
    All other values are returned unchanged.

    Args:
        data: value to convert (typically a result dict).

    Returns:
        The converted value.
    """
    if isinstance(data, dict):
        return {key: convert_numpy_types(value) for key, value in data.items()}
    if isinstance(data, list):
        return [convert_numpy_types(item) for item in data]
    if isinstance(data, np.ndarray):
        return data.tolist()
    # np.generic is the base of every NumPy scalar type, not just
    # float64/int64, so float32/int32/bool_ results also become serializable.
    if isinstance(data, np.generic):
        return data.item()
    return data
def normalize(a, axis=-1, order=2):
    """Scale ``a`` to unit ``order``-norm along ``axis``.

    Slices whose norm is exactly zero are divided by 1 instead of 0, so they
    stay all-zero rather than becoming NaN.

    Args:
        a: input tensor.
        axis: dimension along which the norm is taken.
        order: norm order passed to ``torch.linalg.norm``.

    Returns:
        A new tensor with the same shape as ``a``.
    """
    l2 = torch.linalg.norm(a, dim=axis, ord=order, keepdim=True)
    # Out-of-place on purpose: autograd saves the norm's output for the
    # backward pass, so mutating it in place would make backward() raise.
    l2 = l2.masked_fill(l2 == 0, 1)
    return a / l2
# Both pre-trained models are loaded once at import time and reused by every request.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Aesthetic regression head. Its weights are downloaded via torch.hub, and its
# input width must equal the CLIP image-embedding size (768 for ViT-L/14).
pthpath = "https://huggingface.co/haor/aesthetics/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth"
model = MLP(768)
model.load_state_dict(
    torch.hub.load_state_dict_from_url(pthpath, map_location=device)
)
model.to(device).eval()

# CLIP backbone plus the image preprocessing transform that matches it.
model2, preprocess = clip.load("ViT-L/14", device=device)
def _timed(label, fn, *args):
    """Call ``fn(*args)``, print the elapsed wall time, and return the result."""
    start = time.perf_counter()
    result = fn(*args)
    if device == "cuda":
        # CUDA kernels are launched asynchronously; wait for them so the
        # timing covers the actual GPU work, not just the launch.
        torch.cuda.synchronize()
    print(f"{label} took {time.perf_counter() - start} seconds")
    return result


def _laplacian_variance(image):
    """Return the variance of the Laplacian of ``image`` in greyscale."""
    pixels = np.array(image)
    if pixels.ndim == 3:
        # Accept RGB and RGBA arrays; 2-D arrays are already greyscale.
        code = cv2.COLOR_RGBA2GRAY if pixels.shape[2] == 4 else cv2.COLOR_RGB2GRAY
        pixels = cv2.cvtColor(pixels, code)
    return cv2.Laplacian(pixels, cv2.CV_64F).var()


def predict(image):
    """Predict the CLIP aesthetic score of an image and compute extra metrics.

    Args:
        image: NumPy array from the Gradio ``Image(type="numpy")`` input
            (presumably HxWx3 uint8 RGB), or ``None`` when nothing was
            uploaded.

    Returns:
        A JSON-serializable dict with these keys:
        - ``clip_aesthetic``: score from the MLP head.
        - ``phash``: perceptual hash hex string.
        - ``md5`` / ``sha1``: digests of the raw decoded pixel bytes.
        - ``laplacian_variance``: variance of the greyscale Laplacian.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    image = Image.fromarray(image)

    laplacian_variance = _laplacian_variance(image)
    phash_value = phash(image)
    raw_pixels = image.tobytes()
    md5 = hashlib.md5(raw_pixels).hexdigest()
    sha1 = hashlib.sha1(raw_pixels).hexdigest()

    inputs = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # CLIP image features, then L2-normalized and cast to float32 for the MLP.
        img_emb = _timed("Encoding image", model2.encode_image, inputs)
        img_emb = _timed("Normalizing image", lambda emb: normalize(emb).float(), img_emb)
        prediction = _timed("Making prediction", lambda emb: model(emb).item(), img_emb)

    result = {
        "clip_aesthetic": prediction,
        "phash": phash_value,
        "md5": md5,
        "sha1": sha1,
        "laplacian_variance": laplacian_variance,
    }
    return convert_numpy_types(result)
# Gradio web interface. It is bound to a module-level name and launched only
# when this file is run as a script, so importing the module does not start a server.
title = "CLIP Aesthetic Score"
description = "Upload an image to predict its aesthetic score using the CLIP model and calculate other image metrics."
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.JSON(label="Result"),
    title=title,
    description=description,
    examples=[["example1.jpg"], ["example2.jpg"]],
)

if __name__ == "__main__":
    demo.launch()