Spaces:
Build error
Build error
import torch | |
import clip | |
from PIL import Image | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
# Choose the compute device once at import time: CUDA when torch can see a
# GPU, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load the CLIP ViT-B/32 model together with the preprocessing transform
# returned alongside it; both are module-level globals used by the helpers below.
model, preprocess = clip.load("ViT-B/32", device=device)
def extract_features_cp(pil_img: Image.Image) -> np.ndarray:
    """Embed an in-memory PIL image with the module-level CLIP model.

    Args:
        pil_img: Image to embed; it is converted to a model input by the
            ``preprocess`` transform returned from ``clip.load``.

    Returns:
        A 1-D float32 array holding the L2-normalised image embedding.
    """
    # Add a batch dimension of size 1 and move the tensor to the model's device.
    img = preprocess(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        # encode_image may return half precision (e.g. when the model runs on
        # CUDA); cast to float32 so normalisation is done at full precision and
        # the returned dtype does not depend on the device.
        features = model.encode_image(img).float()
    # L2-normalise so a dot product between two embeddings equals their
    # cosine similarity.
    features = features / features.norm(dim=-1, keepdim=True)
    # Drop the batch dimension and hand back a plain NumPy vector.
    return features.cpu().numpy().flatten()
def extract_features(img_path):
    """Load an image from disk and return its L2-normalised CLIP embedding.

    Args:
        img_path: Path to the image file, passed straight to ``PIL.Image.open``.

    Returns:
        The 1-D NumPy embedding produced by ``extract_features_cp``.
    """
    # Open the file in a context manager so the handle is released even if
    # preprocessing or encoding raises. The embedding logic itself lives in
    # extract_features_cp; delegate instead of duplicating it here.
    with Image.open(img_path) as pil_img:
        return extract_features_cp(pil_img)
def compare_features(features1, features2):
    """Return the cosine similarity between two 1-D feature vectors.

    The result is the single entry of the 1x1 matrix computed by
    scikit-learn's ``cosine_similarity``.
    """
    # cosine_similarity works on 2-D inputs, so treat each vector as a
    # one-row matrix and read back the only cell of the result.
    sim_matrix = cosine_similarity([features1], [features2])
    return sim_matrix[0, 0]
def predict_similarity(features1, features2, threshold=0.5):
    """Decide whether two feature vectors count as similar.

    Args:
        features1: First 1-D feature vector (e.g. from ``extract_features``).
        features2: Second 1-D feature vector.
        threshold: Cosine-similarity cut-off; only scores strictly above it
            count as similar.

    Returns:
        ``True`` if the cosine similarity exceeds ``threshold``, else ``False``.
    """
    # The comparison yields a numpy.bool_; wrap it so callers get a plain bool.
    return bool(compare_features(features1, features2) > threshold)
if __name__ == '__main__':
    import sys

    # Example usage: optionally pass two image paths on the command line;
    # with no arguments, fall back to the original example images.
    args = sys.argv[1:]
    if args and len(args) != 2:
        sys.exit(f'usage: {sys.argv[0]} [IMAGE1 IMAGE2]')
    img_path1, img_path2 = args or ['result.jpg', 'Vochysia.jpg']
    threshold = 0.8

    # Extract features
    features1 = extract_features(img_path1)
    features2 = extract_features(img_path2)

    # Compare features once and reuse the score for the verdict
    # (predict_similarity would recompute the same cosine similarity).
    cos_sim = compare_features(features1, features2)
    print(f'Cosine Similarity: {cos_sim}')

    is_similar = cos_sim > threshold
    print(f'Are the images similar? {"Yes" if is_similar else "No"}')