hitz02 commited on
Commit
e57fe3b
·
1 Parent(s): 37e8a32

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow.compat.v1 as tf
2
+ import os
3
+ import shutil
4
+ import csv
5
+ import pandas as pd
6
+ import numpy as np
7
+ import IPython
8
+ import streamlit as st
9
+ import subprocess
10
+ from itertools import islice
11
+ import random
12
+ #from transformers import pipeline
13
+ from transformers import TapasTokenizer, TapasForQuestionAnswering
14
+
15
+ tf.get_logger().setLevel('ERROR')
16
+
17
+ def install(package):
18
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
19
+
20
+ install('torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu102.html')
21
+
22
+ model_name = 'google/tapas-base-finetuned-wtq'
23
+ #model_name = "table-question-answering"
24
+ #model = pipeline(model_name)
25
+
26
+ model = TapasForQuestionAnswering.from_pretrained(model_name, local_files_only=False)
27
+ tokenizer = TapasTokenizer.from_pretrained(model_name)
28
+
29
+ st.set_option('deprecation.showfileUploaderEncoding', False)
30
+
31
+ st.title('Query your Table')
32
+ st.header('Upload CSV file')
33
+
34
+ uploaded_file = st.file_uploader("Choose your CSV file",type = 'csv')
35
+ placeholder = st.empty()
36
+
37
+ if uploaded_file is not None:
38
+ data = pd.read_csv(uploaded_file)
39
+ data.replace(',','', regex=True, inplace=True)
40
+ if st.checkbox('Want to see the data?'):
41
+ placeholder.dataframe(data)
42
+
43
+ st.header('Enter your queries')
44
+ input_queries = st.text_input('Type your queries separated by comma(,)',value='')
45
+ input_queries = input_queries.split(',')
46
+
47
+ colors1 = ["#"+''.join([random.choice('0123456789ABCDEF') for j in range(6)]) for i in range(len(input_queries))]
48
+ colors2 = ['background-color:'+str(color)+'; color: black' for color in colors1]
49
+
50
+ def styling_specific_cell(x,tags,colors):
51
+ df_styler = pd.DataFrame('', index=x.index, columns=x.columns)
52
+ for idx,tag in enumerate(tags):
53
+ for r,c in tag:
54
+ df_styler.iloc[r, c] = colors[idx]
55
+ return df_styler
56
+
57
+ if st.button('Predict Answers'):
58
+ with st.spinner('It will take approx a minute'):
59
+ data = data.astype(str)
60
+ inputs = tokenizer(table=table, queries=queries, padding='max_length', return_tensors="pt")
61
+ outputs = model(**inputs)
62
+ #outputs = model(table = data, query = queries)
63
+ predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions( inputs, outputs.logits.detach(), outputs.logits_aggregation.detach())
64
+
65
+ id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3:"COUNT"}
66
+ aggregation_predictions_string = [id2aggregation[x] for x in predicted_aggregation_indices]
67
+
68
+ answers = []
69
+
70
+ for coordinates in predicted_answer_coordinates:
71
+ if len(coordinates) == 1:
72
+ # only a single cell:
73
+ answers.append(table.iat[coordinates[0]])
74
+ else:
75
+ # multiple cells
76
+ cell_values = []
77
+ for coordinate in coordinates:
78
+ cell_values.append(table.iat[coordinate])
79
+ answers.append(", ".join(cell_values))
80
+
81
+ st.success('Done! Please check below the answers and its cells highlighted in table above')
82
+
83
+ placeholder.dataframe(data.style.apply(styling_specific_cell,tags=predicted_answer_coordinates,colors=colors2,axis=None))
84
+
85
+ for query, answer, predicted_agg, c in zip(queries, answers, aggregation_predictions_string, colors1):
86
+ st.write('\n')
87
+ st.markdown('<font color={} size=4>**{}**</font>'.format(c,query), unsafe_allow_html=True)
88
+ st.write('\n')
89
+
90
+ if predicted_agg == "NONE" or predicted_agg == 'COUNT':
91
+ st.markdown('**>** '+str(answer))
92
+ else:
93
+ if predicted_agg == 'SUM':
94
+ st.markdown('**>** '+str(sum(answer.split(','))))
95
+ else:
96
+ st.markdown('**>** '+str(np.round(np.mean(answer.split(',')),2)))