rifatramadhani commited on
Commit
b18d110
·
1 Parent(s): 0d37404

refactor: output structure

Browse files
Files changed (1) hide show
  1. app.py +20 -11
app.py CHANGED
@@ -8,8 +8,13 @@ import spaces
8
  import logging
9
  import datetime
10
 
 
 
 
 
11
  @spaces.GPU
12
  def classify(query):
 
13
  model = Detoxify("unbiased-small", device="cuda")
14
 
15
  all_result = []
@@ -25,27 +30,31 @@ def classify(query):
25
  data = [query]
26
  pass
27
 
 
28
  for i in range(len(data)):
29
  result = {}
30
- start_time = datetime.datetime.now()
31
-
32
  df = pd.DataFrame(model.predict(str(data[i])), index=[0])
33
  columns = df.columns
34
 
35
  for i, label in enumerate(columns):
36
  result[label] = df[label][0].round(3).astype("float")
37
 
38
- end_time = datetime.datetime.now()
39
- elapsed_time = end_time - start_time
40
- result["time"] = str(elapsed_time)
41
-
42
- logging.debug("elapsed predict time: %s", str(elapsed_time))
43
- print("elapsed predict time:", str(elapsed_time))
44
-
45
  all_result.append(result)
46
 
47
-
48
- return json.dumps(all_result) if request_type == list else all_result[0]
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  demo = gr.Interface(fn=classify, inputs=["text"], outputs="text")
51
  demo.launch()
 
8
  import logging
9
  import datetime
10
 
11
+ # Load model for first time cache
12
+ model = Detoxify("unbiased-small")
13
+
14
+
15
  @spaces.GPU
16
  def classify(query):
17
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
18
  model = Detoxify("unbiased-small", device="cuda")
19
 
20
  all_result = []
 
30
  data = [query]
31
  pass
32
 
33
+ start_time = datetime.datetime.now()
34
  for i in range(len(data)):
35
  result = {}
36
+
 
37
  df = pd.DataFrame(model.predict(str(data[i])), index=[0])
38
  columns = df.columns
39
 
40
  for i, label in enumerate(columns):
41
  result[label] = df[label][0].round(3).astype("float")
42
 
 
 
 
 
 
 
 
43
  all_result.append(result)
44
 
45
+ end_time = datetime.datetime.now()
46
+ elapsed_time = end_time - start_time
47
+
48
+ logging.debug("elapsed predict time: %s", str(elapsed_time))
49
+ print("elapsed predict time:", str(elapsed_time))
50
+
51
+ output = {}
52
+ output["time"] = str(elapsed_time)
53
+ output["device"] = torch_device
54
+ output["result"] = all_result
55
+
56
+ return json.dumps(output)
57
+
58
 
59
  demo = gr.Interface(fn=classify, inputs=["text"], outputs="text")
60
  demo.launch()