# openedai-moderations / moderations2.py
import dotenv
dotenv.load_dotenv(override=True)

import sys
import time
import argparse
from typing import Union

import numpy as np
import torch
import uvicorn
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

import openedai

app = openedai.OpenAIStub()
moderation = None
device = "cuda" if torch.cuda.is_available() else "cpu"
#device = "cpu"
labels = [
    'hate',
    'hate_threatening',
    'harassment',
    'harassment_threatening',
    'self_harm',
    'self_harm_intent',
    'self_harm_instructions',
    'sexual',
    'sexual_minors',
    'violence',
    'violence_graphic',
]
label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for i, l in enumerate(labels)}
model_name = "/root/autodl-tmp/duanyu027/moderation_0628"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=len(labels), id2label=id2label, label2id=label2id,
    problem_type="multi_label_classification")
model.to(device)

# Optional dynamic quantization for CPU inference:
#model = torch.quantization.quantize_dynamic(
#    model, {torch.nn.Linear}, dtype=torch.qint8
#)

torch.set_num_threads(1)  # limit the CPU threads torch uses for inference

class ModerationsRequest(BaseModel):
    model: str = "text-moderation-latest"  # or "text-moderation-stable"
    input: Union[str, list[str]]
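
# For illustration, a request body accepted by this schema might look like the
# following ("input" may be a single string or a list of strings; the example
# text is hypothetical):
#
#   {"model": "text-moderation-latest", "input": ["I want to hurt them."]}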

@app.post("/v1/moderations")
async def moderations(request: ModerationsRequest):
"""
Sample Response:
{
"id": "modr-XXXXX",
"model": "text-moderation-005",
"results": [
{
"flagged": true,
"categories": {
"sexual": false,
"hate": false,
"harassment": false,
"self-harm": false,
"sexual/minors": false,
"hate/threatening": false,
"violence/graphic": false,
"self-harm/intent": false,
"self-harm/instructions": false,
"harassment/threatening": true,
"violence": true,
},
"category_scores": {
"sexual": 1.2282071e-06,
"hate": 0.010696256,
"harassment": 0.29842457,
"self-harm": 1.5236925e-08,
"sexual/minors": 5.7246268e-08,
"hate/threatening": 0.0060676364,
"violence/graphic": 4.435014e-06,
"self-harm/intent": 8.098441e-10,
"self-harm/instructions": 2.8498655e-11,
"harassment/threatening": 0.63055265,
"violence": 0.99011886,
}
}
]
}
"""
    # Handle the moderations request: run the local multi-label classifier on each
    # input and return OpenAI-compatible moderation results.
    results = {
        "id": f"modr-{int(time.time()*1e9)}",
        "model": "text-moderation-005",
        "results": [],
    }

    # input may be a single string or a list of strings
    if isinstance(request.input, str):
        request.input = [request.input]

    # classification threshold: a category is flagged when its score exceeds this value
    threshold = 0.5
    for text in request.input:
        predictions = predict(text, model, tokenizer)
        # minor name adjustments: map model label names to OpenAI-style category names,
        # e.g. 'hate_threatening' -> 'hate/threatening', 'self_harm' -> 'self-harm'
        category_scores = {
            labels[i].replace('self_harm', 'self-harm').replace('_', '/'): predictions[0][i].item()
            for i in range(len(labels))
        }
        detect = {key: score > threshold for key, score in category_scores.items()}
        detected = any(detect.values())

        results['results'].append({
            'flagged': detected,
            'categories': detect,
            'category_scores': category_scores,
        })

    return results

def sigmoid(x):
    # NumPy sigmoid (currently unused; predict() applies torch.sigmoid directly)
    return 1 / (1 + np.exp(-x))

def parse_args(argv):
    parser = argparse.ArgumentParser(description='Moderation API')
    parser.add_argument('--host', type=str, default='0.0.0.0')
    parser.add_argument('--port', type=int, default=5002)
    parser.add_argument('--test-load', action='store_true')
    return parser.parse_args(argv)

def predict(text, model, tokenizer):
    encoding = tokenizer.encode_plus(
        text,
        truncation=True,  # guard against inputs longer than the model's maximum sequence length
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # convert logits to independent per-label probabilities
    predictions = torch.sigmoid(outputs.logits)
    return predictions
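
# For reference, predict() returns a tensor of shape (1, len(labels)) with one
# probability per label (in `labels` order), since multi-label classification
# applies a sigmoid to each logit independently. The values below are illustrative only:
#
#   probs = predict("some text", model, tokenizer)
#   # probs -> tensor([[0.0100, 0.0003, ..., 0.9700]])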

# Main
if __name__ == "__main__":
    args = parse_args(sys.argv[1:])

    # start API
    print(f'Starting moderations[{device}] API on {args.host}:{args.port}', file=sys.stderr)
    app.register_model('text-moderations-latest', 'text-moderations-stable')
    app.register_model('text-moderations-005', 'text-moderations-ifmain')
    if not args.test_load:
        uvicorn.run(app, host=args.host, port=args.port)
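
# Example request against a locally running instance (a minimal sketch assuming the
# default --host/--port above; the input text is hypothetical):
#
#   curl http://localhost:5002/v1/moderations \
#     -H "Content-Type: application/json" \
#     -d '{"model": "text-moderation-latest", "input": "I want to hurt them."}'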