import torch
import gradio as gr
from detoxify import Detoxify
import pandas as pd
import json
import spaces
import logging
import datetime

# Load the model once at import time so the weights are downloaded and cached
# before the first request arrives.
model = Detoxify("unbiased-small")
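
# model.predict returns a dict of per-label scores; for the "unbiased"
# checkpoints the labels are e.g. toxicity, severe_toxicity, obscene,
# identity_attack, insult, threat and sexual_explicit. A hedged sketch of a
# single prediction (scores are illustrative, not taken from this repo):
#
#   model.predict("you are great")
#   # {'toxicity': 0.0003, 'severe_toxicity': 0.0001, ...}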


@spaces.GPU
def classify(query):
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
    # Re-instantiate the model on the GPU when one is available, otherwise on CPU.
    model = Detoxify("unbiased-small", device=torch_device)
    all_results = []

    # Accept either a plain string or a JSON-encoded list of strings.
    try:
        data = json.loads(query)
        if not isinstance(data, list):
            data = [query]
    except (json.JSONDecodeError, TypeError) as e:
        print(e)
        data = [query]
    start_time = datetime.datetime.now()
    for item in data:
        result = {}
        # predict() returns a dict of label -> score; a one-row DataFrame makes
        # it easy to round every label the same way.
        df = pd.DataFrame(model.predict(str(item)), index=[0])
        for label in df.columns:
            result[label] = float(df[label].iloc[0].round(3))
        all_results.append(result)
    end_time = datetime.datetime.now()

    elapsed_time = end_time - start_time
    logging.debug("elapsed predict time: %s", elapsed_time)
    print("elapsed predict time:", elapsed_time)

    output = {
        "time": str(elapsed_time),
        "device": torch_device,
        "result": all_results,
    }
    return json.dumps(output)


demo = gr.Interface(fn=classify, inputs=["text"], outputs="text")
demo.launch()
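
# A minimal client-side usage sketch (not part of the original app): the
# endpoint accepts either a plain string or a JSON-encoded list of strings and
# returns a JSON string with "time", "device" and a "result" list of per-label
# scores. The Space id below is a placeholder and gradio_client is assumed to
# be installed separately.
#
#   from gradio_client import Client
#
#   client = Client("<user>/<space-name>")  # hypothetical Space id
#   single = client.predict("you are great", api_name="/predict")
#   batch = client.predict('["you are great", "you are awful"]', api_name="/predict")
#   print(single, batch)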