# EraX-NSFW-V1.0 / app.py
# Author: erax
# Last commit: 751cc0c ("Update app.py", verified)
import gradio as gr
import spaces
import torch
from ultralytics import YOLO
from PIL import Image
import supervision as sv
import numpy as np
@spaces.GPU
def yolov8_inference(
    image,
    selected_labels_list
):
    """Detect NSFW regions in an image and pixelate the selected classes.

    Runs the ``erax_nsfw_v1.pt`` YOLO model on CUDA with fixed NMS settings
    (confidence 0.3, IoU 0.2, class-aware NMS, at most 1000 detections),
    keeps only the detections whose class label is in
    ``selected_labels_list``, pixelates those boxes and draws their class
    labels on top.

    Args:
        image: Input image as delivered by the ``gr.Image(type="filepath")``
            input component (a file path), or ``None`` when nothing was
            uploaded.
        selected_labels_list: Labels to censor; each must be one of
            ``"anus"``, ``"make_love"``, ``"nipple"``, ``"penis"``,
            ``"vagina"``. ``None`` or an empty list censors nothing.

    Returns:
        The annotated image as a numpy array with its channel axis reversed
        relative to ``result.orig_img`` (presumably BGR -> RGB, since
        Ultralytics loads images with OpenCV), or ``None`` when no image was
        given or the model returned no result.

    Raises:
        ValueError: If a selected label is not one of the known class names.
    """
    if image is None:
        # Nothing uploaded: clear the output instead of failing inside YOLO.
        return None

    # Class names indexed by class id: a label's position is its class id.
    class_names = ("anus", "make_love", "nipple", "penis", "vagina")
    # tuple.index raises ValueError for an unknown label, like the original.
    selected_classes = [
        class_names.index(label) for label in (selected_labels_list or [])
    ]

    # NOTE(review): the weights are reloaded on every request; loading the
    # model once at module level may be possible on ZeroGPU -- confirm first.
    model = YOLO('erax_nsfw_v1.pt').to('cuda')
    model.overrides['conf'] = 0.3  # NMS confidence threshold
    model.overrides['iou'] = 0.2  # NMS IoU threshold
    model.overrides['agnostic_nms'] = False  # class-aware NMS
    model.overrides['max_det'] = 1000  # maximum number of detections per image

    results = model([image])
    if not results:
        return None
    # Exactly one image was submitted, so only the first result is relevant.
    result = results[0]

    # Private copy so the model's own image array is never modified.
    annotated_image = result.orig_img.copy()
    h, w = annotated_image.shape[:2]
    # Pixel block size and label text size scale with the longer image side.
    anchor = max(h, w)

    detections = sv.Detections.from_ultralytics(result)
    detections = detections[np.isin(detections.class_id, selected_classes)]

    pixelate_annotator = sv.PixelateAnnotator(pixel_size=anchor / 50)
    label_annotator = sv.LabelAnnotator(
        text_color=sv.Color.BLACK,
        text_position=sv.Position.CENTER,
        text_scale=anchor / 1700,
    )
    # Pixelate first, then draw the labels so they stay readable.
    annotated_image = pixelate_annotator.annotate(
        scene=annotated_image,
        detections=detections,
    )
    annotated_image = label_annotator.annotate(
        annotated_image,
        detections=detections,
    )
    return annotated_image[:, :, ::-1]
# Detector classes offered to the user; also pre-selected in every example.
NSFW_LABELS = ["anus", "make_love", "nipple", "penis", "vagina"]

inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.CheckboxGroup(NSFW_LABELS, label="Input Labels"),
]
outputs = gr.Image(type="filepath", label="Output Image")
title = "EraX NSFW V1.0 Models for NSFW detection"
# Each demo image is run with every label selected. Every row gets its own
# list copy so the rows never share one mutable object.
examples = [
    ['demo/img_1.jpg', list(NSFW_LABELS)],
    ['demo/img_2.jpg', list(NSFW_LABELS)],
    ['demo/img_3.jpg', list(NSFW_LABELS)],
]
demo_app = gr.Interface(
    fn=yolov8_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    # Cache the example outputs so selecting an example does not rerun the model.
    cache_examples=True,
)
demo_app.launch(debug=True)