import dotenv

dotenv.load_dotenv(override=True)

import sys
import time
import argparse
import uvicorn
from typing import Union
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import openedai
import numpy as np

app = openedai.OpenAIStub()
moderation = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# Moderation categories predicted by the classifier. These mirror the OpenAI
# moderation categories, with '_' in place of '/' and '-'.
labels = ['hate',
          'hate_threatening',
          'harassment',
          'harassment_threatening',
          'self_harm',
          'self_harm_intent',
          'self_harm_instructions',
          'sexual',
          'sexual_minors',
          'violence',
          'violence_graphic',
          ]

label2id = {l: i for i, l in enumerate(labels)}
id2label = {i: l for i, l in enumerate(labels)}

model_name = "/root/autodl-tmp/duanyu027/moderation_0628"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    problem_type="multi_label_classification",
)
model.to(device)

# Limit PyTorch to a single CPU thread for intra-op parallelism.
torch.set_num_threads(1)


class ModerationsRequest(BaseModel):
    model: str = "text-moderation-latest"
    input: Union[str, list[str]]


@app.post("/v1/moderations")
async def moderations(request: ModerationsRequest):
    """
    Sample Response:
    {
      "id": "modr-XXXXX",
      "model": "text-moderation-005",
      "results": [
        {
          "flagged": true,
          "categories": {
            "sexual": false,
            "hate": false,
            "harassment": false,
            "self-harm": false,
            "sexual/minors": false,
            "hate/threatening": false,
            "violence/graphic": false,
            "self-harm/intent": false,
            "self-harm/instructions": false,
            "harassment/threatening": true,
            "violence": true
          },
          "category_scores": {
            "sexual": 1.2282071e-06,
            "hate": 0.010696256,
            "harassment": 0.29842457,
            "self-harm": 1.5236925e-08,
            "sexual/minors": 5.7246268e-08,
            "hate/threatening": 0.0060676364,
            "violence/graphic": 4.435014e-06,
            "self-harm/intent": 8.098441e-10,
            "self-harm/instructions": 2.8498655e-11,
            "harassment/threatening": 0.63055265,
            "violence": 0.99011886
          }
        }
      ]
    }
    """
    results = {
        "id": f"modr-{int(time.time()*1e9)}",
        "model": "text-moderation-005",
        "results": [],
    }

    if isinstance(request.input, str):
        request.input = [request.input]

    # A category is flagged when its probability exceeds this threshold.
    threshold = 0.5

    for text in request.input:
        predictions = predict(text, model, tokenizer)
        category_scores = {labels[i]: predictions[0][i].item() for i in range(len(labels))}
        detect = {key: score > threshold for key, score in category_scores.items()}
        detected = any(detect.values())

        results['results'].append({
            'flagged': detected,
            'categories': detect,
            'category_scores': category_scores,
        })

    return results
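

# Example request (a sketch, assuming the server is running locally on the
# default port from parse_args below; adjust host/port as needed):
#
#   curl http://localhost:5002/v1/moderations \
#     -H "Content-Type: application/json" \
#     -d '{"input": "some text to classify"}'
#
# The response follows the sample in the docstring above, with one entry in
# "results" per input string; the category keys come from `labels` (the
# underscored variants of the names shown in the sample).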


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def parse_args(argv):
    parser = argparse.ArgumentParser(description='Moderation API')
    parser.add_argument('--host', type=str, default='0.0.0.0')
    parser.add_argument('--port', type=int, default=5002)
    parser.add_argument('--test-load', action='store_true')
    return parser.parse_args(argv)
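

# Example startup (a sketch; the filename "moderations.py" is assumed, not
# fixed by this code):
#
#   python moderations.py --host 0.0.0.0 --port 5002
#
# With --test-load the model is loaded and the script exits without starting
# the HTTP server.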


def predict(text, model, tokenizer):
    # Tokenize the input; truncate to the model's maximum length so long
    # inputs do not raise an error.
    encoding = tokenizer(
        text,
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    model.eval()
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # Multi-label head: apply an element-wise sigmoid to get independent
    # per-category probabilities.
    predictions = torch.sigmoid(outputs.logits)
    return predictions
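

# For a single input string, predict() returns a tensor of shape
# (1, len(labels)) holding independent per-category probabilities
# (values below are illustrative):
#
#   predict("some text", model, tokenizer)
#   -> tensor([[0.0100, 0.0004, ..., 0.9800]])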


if __name__ == "__main__":
    args = parse_args(sys.argv[1:])

    print(f'Starting moderations[{device}] API on {args.host}:{args.port}', file=sys.stderr)

    app.register_model('text-moderations-latest', 'text-moderations-stable')
    app.register_model('text-moderations-005', 'text-moderations-ifmain')

    if not args.test_load:
        uvicorn.run(app, host=args.host, port=args.port)