|
from gevent import pywsgi |
|
import sys |
|
import time |
|
import argparse |
|
import uvicorn |
|
from typing import Union |
|
from pydantic import BaseModel |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
import os |
|
import openedai |
|
import numpy as np |
|
import asyncio |
|
from urllib.parse import urlparse |
|
import nacos |
|
import configparser |
|
|
|
|
|
|
|
# Application object (an openedai.OpenAIStub) that the handlers below register on.
app = openedai.OpenAIStub()

# Placeholder; nothing in the visible part of this module assigns or reads it.
moderation = None

# Run inference on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
# Moderation category names. The position of each name is the label id that
# is handed to the classifier through id2label / label2id, so the order
# matters and must not be changed.
labels = [
    "hate",
    "hate_threatening",
    "harassment",
    "harassment_threatening",
    "self_harm",
    "self_harm_intent",
    "self_harm_instructions",
    "sexual",
    "sexual_minors",
    "violence",
    "violence_graphic",
]
|
|
|
# Name -> index and index -> name mappings passed to the model configuration.
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = dict(enumerate(labels))

# Checkpoint identifier used for both the tokenizer and the classifier.
model_name = "duanyu027/moderation_0628"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    problem_type="multi_label_classification",
)

# Move the weights to the selected device and switch to evaluation mode.
model.to(device)
model.eval()

# Restrict torch to a single thread for its CPU-side parallelism.
torch.set_num_threads(1)
|
def register_service(client,service_name,service_ip,service_port,cluster_name,health_check_interval,weight,http_proxy,domain,protocol,direct_domain):
    """Register this service instance with Nacos as an ephemeral instance.

    The instance metadata carries two keys:

    * ``http_proxy`` -- ``True``/``False`` according to ``http_proxy``.
    * ``domain`` -- ``"{domain}/port/{service_port}"`` when ``http_proxy`` is
      set and ``direct_domain`` is not; otherwise the direct address
      ``"{protocol}://{service_ip}:{service_port}"``.

    Args:
        client: Object exposing ``add_naming_instance`` (a Nacos client).
        service_name: Name to register the instance under.
        service_ip: Host/IP advertised for the instance.
        service_port: Port advertised for the instance.
        cluster_name: Nacos cluster the instance belongs to.
        health_check_interval: Passed through as ``heartbeat_interval``.
        weight: Instance weight.
        http_proxy: Whether the service is reached through an HTTP proxy.
        domain: Proxy base URL, used only for the proxied, non-direct case.
        protocol: URL scheme used when building the direct address.
        direct_domain: When proxied, advertise the direct address anyway.

    Returns:
        Whatever ``client.add_naming_instance`` returns on success, or
        ``False`` when the call raises. Callers test the result for
        truthiness, so a failure must not be reported as ``True``.
    """
    direct_address = f"{protocol}://{service_ip}:{service_port}"
    if http_proxy and not direct_domain:
        advertised = f"{domain}/port/{service_port}"
    else:
        advertised = direct_address
    metadata = {"http_proxy": bool(http_proxy), "domain": advertised}

    try:
        return client.add_naming_instance(
            service_name,
            service_ip,
            service_port,
            cluster_name,
            weight,
            metadata,
            enable=True,
            healthy=True,
            ephemeral=True,
            heartbeat_interval=health_check_interval,
        )
    except Exception as e:
        # Boundary with the Nacos client: report the problem and signal
        # failure instead of letting the caller believe registration worked.
        print(f"Error registering service to Nacos: {e}")
        return False
|
|
|
class ModerationsRequest(BaseModel):
    """Request body accepted by the ``/v1/moderations`` endpoint."""

    # Model name supplied by the client; defaults to "text-moderation-latest".
    model: str = "text-moderation-latest"

    # Either a single text or a list of texts to classify.
    input: Union[str, list[str]]
|
@app.on_event("startup")
def startup_event():
    """Register this service with Nacos when the application starts.

    Reads Nacos and server settings from ``config.ini`` in the current
    working directory, derives the advertised host and port from the
    ``AutoDLServiceURL`` environment variable, and calls
    ``register_service``.

    Raises:
        RuntimeError: If ``config.ini`` cannot be read, if
            ``AutoDLServiceURL`` is unset or does not contain a valid host
            and port, or if ``register_service`` returns a falsy value.
    """
    config = configparser.ConfigParser()

    # ConfigParser.read returns the list of files it managed to parse; an
    # empty list means config.ini was not found or not readable.
    if not config.read('config.ini'):
        raise RuntimeError("配置文件不存在")

    nacos_cfg = config['nacos']
    nacos_server = nacos_cfg['nacos_server']
    namespace = nacos_cfg['namespace']
    cluster_name = nacos_cfg['cluster_name']
    client = nacos.NacosClient(
        nacos_server,
        namespace=namespace,
        username=nacos_cfg['username'],
        password=nacos_cfg['password'],
    )
    service_name = nacos_cfg['service_name']
    health_check_interval = int(nacos_cfg['health_check_interval'])
    weight = int(nacos_cfg.get('weight', fallback='1'))

    http_proxy = config.getboolean('server', 'http_proxy')
    domain = config['server']['domain']
    protocol = config['server']['protocol']
    direct_domain = config.getboolean('server', 'direct_domain')

    autodl_url = os.environ.get('AutoDLServiceURL')
    if not autodl_url:
        raise RuntimeError("Error: AutoDLServiceURL environment variable is not set.")

    # urlparse() and the .port property raise ValueError for a malformed
    # netloc or a non-numeric / out-of-range port. Report that the same way
    # as a missing host or port instead of leaking a bare ValueError.
    try:
        parsed_url = urlparse(autodl_url)
        service_ip = parsed_url.hostname
        service_port = parsed_url.port
    except ValueError as err:
        raise RuntimeError("Error: Invalid AutoDLServiceURL format.") from err
    if not service_ip or not service_port:
        raise RuntimeError("Error: Invalid AutoDLServiceURL format.")

    registered = register_service(
        client,
        service_name,
        service_ip,
        service_port,
        cluster_name,
        health_check_interval,
        weight,
        http_proxy,
        domain,
        protocol,
        direct_domain,
    )
    if not registered:
        raise RuntimeError("Service is healthy but failed to register.")
|
@app.post("/v1/moderations")
async def moderations(request: ModerationsRequest):
    """Classify each input text and return per-category moderation results.

    A single string input is wrapped into a one-element list. For every
    text, the score of each category is compared against that category's
    threshold; the text is flagged when at least one category exceeds it.

    Returns:
        A dict with ``id``, ``model`` and ``results``; ``results`` holds one
        entry per input text with ``flagged``, ``categories`` and
        ``category_scores``.
    """
    # Time-based identifier, taken before any inference work is done.
    response_id = f"modr-{int(time.time()*1e9)}"

    if isinstance(request.input, str):
        request.input = [request.input]

    # Categories flagged above 0.5 versus those that need a score above 0.9.
    lenient = ("sexual", "hate", "harassment", "self_harm", "violence")
    strict = (
        "sexual_minors",
        "hate_threatening",
        "violence_graphic",
        "self_harm_intent",
        "self_harm_instructions",
        "harassment_threatening",
    )
    thresholds = {**dict.fromkeys(lenient, 0.5), **dict.fromkeys(strict, 0.9)}

    entries = []
    for text in request.input:
        predictions = await predict(text, model, tokenizer)
        row = predictions[0]
        category_scores = {name: row[idx].item() for idx, name in enumerate(labels)}
        detect = {
            name: value > thresholds[name]
            for name, value in category_scores.items()
        }
        entries.append({
            'flagged': any(detect.values()),
            'categories': detect,
            'category_scores': category_scores,
        })

    return {
        "id": response_id,
        "model": "text-moderation-005",
        "results": entries,
    }
|
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``.

    Computed as ``exp(-log(1 + exp(-x)))`` via ``np.logaddexp``, so large
    negative inputs do not overflow in ``exp`` the way the direct formula
    does.

    Args:
        x: A scalar or NumPy array.

    Returns:
        The element-wise sigmoid, with the same shape as ``x``.
    """
    # logaddexp(0, -x) == log(exp(0) + exp(-x)) == log(1 + exp(-x)),
    # evaluated without forming exp(-x) directly.
    return np.exp(-np.logaddexp(0, -x))
|
|
|
def parse_args(argv):
    """Parse the command-line options of the moderation API.

    Args:
        argv: List of argument strings (without the program name).

    Returns:
        An ``argparse.Namespace`` with ``host`` (default ``'0.0.0.0'``),
        ``port`` (default ``5002``) and ``test_load`` (default ``False``).
    """
    cli = argparse.ArgumentParser(description='Moderation API')

    # (flag, keyword arguments) pairs, registered in declaration order.
    option_table = (
        ('--host', {'type': str, 'default': '0.0.0.0'}),
        ('--port', {'type': int, 'default': 5002}),
        ('--test-load', {'action': 'store_true'}),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args(argv)
|
|
|
async def predict(text, model, tokenizer):
    """Score one text with the classifier without blocking the event loop.

    The text is tokenized with truncation enabled, so over-long inputs are
    cut to the tokenizer's configured maximum length rather than being
    passed to the model at full length. The forward pass runs in the event
    loop's default executor.

    Args:
        text: The text to classify.
        model: Sequence-classification model; called with ``input_ids`` and
            ``attention_mask`` and expected to return an object with
            ``.logits``.
        tokenizer: Tokenizer callable returning ``input_ids`` and
            ``attention_mask`` tensors.

    Returns:
        ``torch.sigmoid`` of the model logits, on the module-level
        ``device``. NOTE(review): presumably shaped (1, num_labels), since
        the caller indexes ``[0][i]`` -- confirm.
    """
    encoding = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    def _predict():
        # Grad mode is per-thread in torch, so no_grad has to be entered
        # inside the function that runs on the executor thread.
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            return torch.sigmoid(outputs.logits)

    loop = asyncio.get_running_loop()
    try:
        return await loop.run_in_executor(None, _predict)
    finally:
        # Drop the input tensors before asking the CUDA caching allocator
        # to release unused blocks; this now also happens when the forward
        # pass fails.
        del input_ids
        del attention_mask
        torch.cuda.empty_cache()
|
|
|
if __name__ == "__main__":
    # Serve the app by import string ("moderations:app"), with auto-reload on.
    uvicorn.run(
        "moderations:app",
        host="0.0.0.0",
        port=6006,
        reload=True,
    )
|
|