# glass_try_on1 / app.py (Siyun He)
import cv2
import cvzone
import numpy as np
import os
import gradio as gr
from datetime import datetime
# Load the YuNet model
model_path = 'face_detection_yunet_2023mar.onnx'
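# A minimal guard (a sketch, not part of the original app): fail fast with a clear
# message if the YuNet weights are missing. The ONNX file is expected to sit next
# to app.py; it can be obtained from the OpenCV Zoo repository (not fetched here).
if not os.path.exists(model_path):
    raise FileNotFoundError(
        f"{model_path} not found; download the YuNet ONNX model (e.g. from the "
        "OpenCV Zoo) and place it next to app.py"
    )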
face_detector = cv2.FaceDetectorYN.create(model_path, "", (320, 320))
# Initialize the glass number
num = 1
overlay = cv2.imread(f'glasses/glass{num}.png', cv2.IMREAD_UNCHANGED)
# Count glasses files
def count_files_in_directory(directory):
    file_count = 0
    for root, dirs, files in os.walk(directory):
        file_count += len(files)
    return file_count
directory_path = 'glasses'
total_glass_num = count_files_in_directory(directory_path)
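# A stricter variant (a sketch, not used above): count only glass*.png files so
# stray entries such as .DS_Store or a README do not inflate the total and break
# the glasses/glass{num}.png naming scheme assumed by change_glasses() below.
def count_glasses_pngs(directory):
    import glob
    return len(glob.glob(os.path.join(directory, 'glass*.png')))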
# Change glasses
def change_glasses():
    global num, overlay
    num += 1
    if num > total_glass_num:
        num = 1
    overlay = cv2.imread(f'glasses/glass{num}.png', cv2.IMREAD_UNCHANGED)
    return overlay
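# Defensive loader (a sketch, not wired into the app above): cv2.imread returns
# None for a missing file, and cvzone.overlayPNG expects a 4-channel (BGRA) image,
# so failing fast here gives a clearer error than a crash later inside process_frame().
def load_glasses(n):
    img = cv2.imread(f'glasses/glass{n}.png', cv2.IMREAD_UNCHANGED)
    if img is None or img.ndim != 3 or img.shape[2] != 4:
        raise ValueError(f"glasses/glass{n}.png is missing or has no alpha channel")
    return img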
# Process frame for overlay
def process_frame(frame):
    global overlay
    # Ensure the frame is writable
    frame = np.array(frame, copy=True)
    height, width = frame.shape[:2]
    face_detector.setInputSize((width, height))
    _, faces = face_detector.detect(frame)
    if faces is not None:
        # Each detected face is a row holding the bounding box (x, y, w, h),
        # five landmark points (eyes, nose tip, mouth corners) and a confidence score
        for face in faces:
            x, y, w, h = face[:4].astype(int)
            face_landmarks = face[4:14].reshape(5, 2).astype(int)  # Facial landmarks
            # Get the nose position
            nose_x, nose_y = face_landmarks[2].astype(int)
            # Left and right eye positions
            left_eye_x, left_eye_y = face_landmarks[0].astype(int)
            right_eye_x, right_eye_y = face_landmarks[1].astype(int)
            # Calculate the midpoint between the eyes
            eye_center_x = (left_eye_x + right_eye_x) // 2
            eye_center_y = (left_eye_y + right_eye_y) // 2
            # Calculate the angle of rotation from the eye line
            delta_x = right_eye_x - left_eye_x
            delta_y = right_eye_y - left_eye_y
            angle = np.degrees(np.arctan2(delta_y, delta_x))
            # Negate the angle to rotate in the opposite direction
            angle = -angle
            # Resize the overlay to roughly match the face width
            overlay_resize = cv2.resize(overlay, (int(w * 1.15), int(h * 0.8)))
            # Rotate the overlay around its own center
            overlay_center = (overlay_resize.shape[1] // 2, overlay_resize.shape[0] // 2)
            rotation_matrix = cv2.getRotationMatrix2D(overlay_center, angle, 1.0)
            overlay_rotated = cv2.warpAffine(
                overlay_resize, rotation_matrix,
                (overlay_resize.shape[1], overlay_resize.shape[0]),
                flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0, 0)
            )
            # Calculate the position to center the glasses on the eyes
            overlay_x = eye_center_x - overlay_rotated.shape[1] // 2
            overlay_y = eye_center_y - overlay_rotated.shape[0] // 2
            # Overlay the glasses
            try:
                frame = cvzone.overlayPNG(frame, overlay_rotated, [overlay_x, overlay_y])
            except Exception as e:
                print(f"Error overlaying glasses: {e}")
    return frame
# Transform function
def transform_cv2(frame, transform):
    if transform == "cartoon":
        # prepare color
        img_color = cv2.pyrDown(cv2.pyrDown(frame))
        for _ in range(6):
            img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
        img_color = cv2.pyrUp(cv2.pyrUp(img_color))
        # prepare edges
        img_edges = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        img_edges = cv2.adaptiveThreshold(
            cv2.medianBlur(img_edges, 7),
            255,
            cv2.ADAPTIVE_THRESH_MEAN_C,
            cv2.THRESH_BINARY,
            9,
            2,
        )
        img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2RGB)
        # combine color and edges
        img = cv2.bitwise_and(img_color, img_edges)
        return img
    elif transform == "edges":
        # perform edge detection
        img = cv2.cvtColor(cv2.Canny(frame, 100, 200), cv2.COLOR_GRAY2BGR)
        return img
    else:
        return frame
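# A quick offline check (a sketch, not wired into the app): run the same pipeline
# on a single still image instead of the webcam stream. 'test.jpg' and
# 'test_out.png' are placeholder paths, not files shipped with this Space.
# Note the cartoon branch implicitly assumes frame dimensions divisible by 4;
# otherwise the pyrDown/pyrUp round-trip changes the size and cv2.bitwise_and raises.
# img = cv2.cvtColor(cv2.imread('test.jpg'), cv2.COLOR_BGR2RGB)
# out = transform_cv2(process_frame(img), "cartoon")
# cv2.imwrite('test_out.png', cv2.cvtColor(out, cv2.COLOR_RGB2BGR))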
def refresh_interface():
    # # Reset the transformation dropdown to its default value
    # transform.update(value="none")
    # # Reset the image to an empty state or a default image
    # # (left disabled like the line above: calling a component's update() outside
    # # an event's outputs is deprecated in recent Gradio releases and can raise at runtime)
    # input_img.update(value=None)
    # Return a message indicating the interface has been refreshed
    return "Interface refreshed!"
def save_frame(frame):
    # Gradio supplies the frame as RGB; convert to BGR so cv2.imwrite saves the colors correctly
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    # Create a unique filename using the current timestamp
    filename = f"saved_frame_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
    # Save the frame
    cv2.imwrite(filename, frame)
    # Refresh the interface
    refresh_interface()
    return f"Frame saved as '{filename}'"
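# Note: the image is written to the app's working directory, which is typically
# ephemeral on a hosted Space. A variant that collects saves in a subfolder
# (a sketch; 'saved' is an arbitrary name, not something the app defines):
# os.makedirs('saved', exist_ok=True)
# cv2.imwrite(os.path.join('saved', filename), frame)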
# Gradio webcam input
def webcam_input(frame, transform):
    frame = process_frame(frame)
    frame = transform_cv2(frame, transform)
    return frame
# Gradio Interface
with gr.Blocks() as demo:
    with gr.Column(elem_classes=["my-column"]):
        with gr.Group(elem_classes=["my-group"]):
            transform = gr.Dropdown(choices=["cartoon", "edges", "none"],
                                    value="none", label="Transformation")
            input_img = gr.Image(sources=["webcam"], type="numpy", streaming=True)
            next_button = gr.Button("Next Glasses")
            save_button = gr.Button("Save as a Picture")
        input_img.stream(webcam_input, [input_img, transform], [input_img],
                         time_limit=30, stream_every=0.1)
    with gr.Row():
        next_button.click(change_glasses, [], [])
    with gr.Row():
        save_button.click(save_frame, [input_img], [])
if __name__ == "__main__":
    demo.launch(share=True)