# Spaces: Running on Zero
import functools

import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from PIL import Image
from ultralytics import YOLO
# Detector class names, indexed by the class id the EraX NSFW weights emit.
# NOTE(review): this ordering is copied from the original hard-coded mapping;
# confirm it against `model.names` if the weights file is ever replaced.
NSFW_CLASS_NAMES = ("anus", "make_love", "nipple", "penis", "vagina")


@functools.lru_cache(maxsize=None)  # keyed by (weights, device): a handful of entries at most
def _load_model(weights="erax_nsfw_v1.pt", device="cuda"):
    """Load the YOLO weights onto `device` once and configure NMS settings."""
    model = YOLO(weights).to(device)
    model.overrides["conf"] = 0.3  # NMS confidence threshold
    model.overrides["iou"] = 0.2  # NMS IoU threshold
    model.overrides["agnostic_nms"] = False  # per-class (not class-agnostic) NMS
    model.overrides["max_det"] = 1000  # maximum number of detections per image
    return model


@spaces.GPU  # ZeroGPU only attaches a GPU inside decorated functions
def yolov8_inference(image, selected_labels_list):
    """Pixelate and label the selected NSFW regions detected in an image.

    Args:
        image: Path to the input image (the Gradio input component uses
            ``type="filepath"``); passed directly to the YOLO model.
        selected_labels_list: Label names (a subset of ``NSFW_CLASS_NAMES``)
            whose detections should be censored. ``None`` or an empty list
            censors nothing.

    Returns:
        The annotated image as an RGB numpy array.

    Raises:
        gr.Error: if no image was provided.
        ValueError: if a selected label is not one of ``NSFW_CLASS_NAMES``.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Fall back to CPU so the app still runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _load_model(device=device)

    # Map the chosen label names to class ids once, before inference.
    labels = selected_labels_list or []
    selected_classes = [NSFW_CLASS_NAMES.index(label) for label in labels]

    # One image in, so exactly one result out.
    result = model([image])[0]

    # orig_img is in OpenCV's BGR channel order, hence the flip on return.
    annotated_image = result.orig_img.copy()
    h, w = annotated_image.shape[:2]
    anchor = max(h, w)  # scale annotations with the longer image side

    detections = sv.Detections.from_ultralytics(result)
    detections = detections[np.isin(detections.class_id, selected_classes)]

    pixelate_annotator = sv.PixelateAnnotator(pixel_size=anchor / 50)
    label_annotator = sv.LabelAnnotator(
        text_color=sv.Color.BLACK,
        text_position=sv.Position.CENTER,
        text_scale=anchor / 1700,
    )
    # annotated_image is already a private copy, so annotate it directly.
    annotated_image = pixelate_annotator.annotate(
        scene=annotated_image,
        detections=detections,
    )
    annotated_image = label_annotator.annotate(
        annotated_image,
        detections=detections,
    )

    # BGR -> RGB for display in Gradio.
    return annotated_image[:, :, ::-1]
# Label choices offered in the UI; each must be a class name that
# yolov8_inference knows how to map to a detector class id.
LABEL_CHOICES = ["anus", "make_love", "nipple", "penis", "vagina"]

inputs = [
    gr.Image(type="filepath", label="Input Image"),
    gr.CheckboxGroup(LABEL_CHOICES, label="Input Labels"),
]
outputs = gr.Image(type="filepath", label="Output Image")
title = "EraX NSFW V1.0 Models for NSFW detection"

# Every example pre-selects all labels, so the demo images censor all classes.
examples = [
    ["demo/img_1.jpg", list(LABEL_CHOICES)],
    ["demo/img_2.jpg", list(LABEL_CHOICES)],
    ["demo/img_3.jpg", list(LABEL_CHOICES)],
]

demo_app = gr.Interface(
    fn=yolov8_inference,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=examples,
    cache_examples=True,  # precompute example outputs instead of re-running inference on click
)

# Launch only when run as a script, so importing this module (e.g. Gradio's
# reload mode) does not start a second server.
if __name__ == "__main__":
    demo_app.launch(debug=True)