Spaces:
Running
Running
from typing import Union, List | |
import requests | |
import uvicorn | |
from fastapi import BackgroundTasks, FastAPI | |
import img_label | |
from img_nsfw import init_nsfw_pipe, check_nsfw | |
import model | |
# ASGI application instance; handed to uvicorn by the script entry point.
app: FastAPI = FastAPI()
def write_scan_img_result(image_id: int, scans: List[int], img: str, callback: str):
    """Run the requested scans on an image and report the results.

    Scan code ``0`` runs the NSFW check and scan code ``1`` runs the image
    tagger; other codes in ``scans`` are ignored.  The tagger results are
    POSTed to ``callback`` as an ``ImageScanCallbackRequest`` and the NSFW
    results are returned to the caller as "Moderation" tags.

    Args:
        image_id: Identifier echoed back in the callback payload.
        scans: Scan codes to run (``0`` = NSFW check, ``1`` = tagging).
        img: Image reference handed to the scanners (the request handler
            passes the request's ``url``).
        callback: URL that receives the tagging results.  When empty, the
            callback POST is skipped.

    Returns:
        A ``model.ImageScanResponse`` with one "Moderation" tag per NSFW
        result (an empty list when the NSFW scan was not requested).
    """
    score_general_threshold = 0.35
    score_character_threshold = 0.85
    # Upper bound for the callback POST so an unresponsive receiver cannot
    # hold this worker forever.
    callback_timeout_seconds = 10

    nsfw_tags = []
    img_tags = []
    if 0 in scans:
        # NOTE(review): `pipe` is a module-level global that is only bound
        # under the `__main__` guard of this file - confirm it is also
        # initialised when the module is imported by an ASGI server.
        nsfw_tags = check_nsfw(img, pipe)
    if 1 in scans:
        img_tags = img_label.label_img(
            image=img,
            model="SwinV2",
            l_score_general_threshold=score_general_threshold,
            l_score_character_threshold=score_character_threshold,
        )['general_res']
    print(nsfw_tags)
    print(img_tags)

    callback_tags = [
        model.ImageTag(tag=item['tag'], confidence=item['confidence'])
        for item in img_tags
    ]
    callback_request = model.ImageScanCallbackRequest(
        id=image_id, isValid=True, tags=callback_tags
    )
    if callback:
        try:
            requests.post(
                callback,
                json=callback_request.dict(),
                timeout=callback_timeout_seconds,
            )
        except Exception as ex:
            # Best effort: a failed callback must not discard the scan result.
            print(ex)

    # Build a real list (not a lazy ``map`` iterator) so the response model
    # receives the same kind of value as the callback payload above.
    moderation_tags = [
        model.ImageScanTag(type="Moderation", confidence=item['confidence'])
        for item in nsfw_tags
    ]
    return model.ImageScanResponse(
        ok=True, error="", deleted=False, blockedFor=[], tags=moderation_tags
    )
def write_scan_model_result(model_name: str, callback: str):
    """Placeholder for model scanning; does nothing and returns ``None``."""
    return None
# Disabled endpoint kept for reference; write_scan_model_result is still a stub.
# @app.post("/model-scan")
# async def send_notification(email: str, background_tasks: BackgroundTasks):
#     background_tasks.add_task(write_scan_model_result, email, callback="")
#     return {"message": "Notification sent in the background"}
async def image_scan_handler(req: model.ImageScanRequest, background_tasks: BackgroundTasks):
    """Handle an image scan request, either synchronously or in the background.

    When ``req.wait`` is falsy the scan is queued on ``background_tasks`` and
    an empty successful response is returned immediately.  Otherwise the scan
    runs to completion and its response is returned.

    Args:
        req: Scan request; ``imageId``, ``scans``, ``url``, ``callbackUrl``
            and ``wait`` are read from it.
        background_tasks: Task queue used for the fire-and-forget path.

    Returns:
        A ``model.ImageScanResponse``: empty ``tags`` for the background
        path, or the result of ``write_scan_img_result`` when waiting.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view - confirm it is registered on `app`.

    # Local import: keeps this block self-contained without touching the
    # file's import section (fastapi is already a dependency of this file).
    from fastapi.concurrency import run_in_threadpool

    scan_kwargs = dict(
        image_id=req.imageId,
        scans=req.scans,
        img=req.url,
        callback=req.callbackUrl,
    )

    if not req.wait:
        background_tasks.add_task(write_scan_img_result, **scan_kwargs)
        return model.ImageScanResponse(ok=True, error="", deleted=False, blockedFor=[], tags=[])

    # The scan does blocking model inference and a blocking HTTP POST; run it
    # in the worker threadpool (as sync background tasks already are) so the
    # event loop stays free to serve other requests while we wait.
    return await run_in_threadpool(write_scan_img_result, **scan_kwargs)
if __name__ == "__main__":
    # `pipe` is read as a module-level global by write_scan_img_result, so it
    # is bound here before the server starts handling requests.  (A `global`
    # statement is unnecessary at module scope.)
    pipe = init_nsfw_pipe()
    server_options = {"host": "0.0.0.0", "port": 6006}
    uvicorn.run(app, **server_options)