File size: 4,058 Bytes
e57fe3b
 
 
 
5a03966
e57fe3b
 
 
 
 
 
 
 
 
 
 
 
f27c7ea
d99dd27
 
c1a9068
ceb06f2
f469d56
 
d99dd27
f27c7ea
e57fe3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import csv
import html
import importlib.util
import os 
import random
import shutil
import subprocess
import sys
from itertools import islice

import IPython
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow.compat.v1 as tf
#from transformers import pipeline
from transformers import TapasTokenizer, TapasForQuestionAnswering

tf.get_logger().setLevel('ERROR')

# Best-effort runtime install of torch-scatter from the PyG wheel index.
# NOTE(review): presumably required by the TAPAS PyTorch model in the installed
# transformers version; the wheel index targets torch 1.10.0 + CUDA 10.2 —
# confirm it matches the installed torch.
TORCH_SCATTER_WHEELS = 'https://data.pyg.org/whl/torch-1.10.0+cu102.html'
# Skip pip entirely when the package is already importable, so re-running the
# script does not spawn a pip process each time.
if importlib.util.find_spec('torch_scatter') is None:
    print(sys.executable)
    try:
        # Install with the interpreter running this app so the package lands
        # in the same environment that will import it.
        subprocess.check_call([sys.executable, '-m', 'pip', 'install',
                               'torch-scatter', '-f', TORCH_SCATTER_WHEELS])
    except (subprocess.SubprocessError, OSError) as e:
        # Deliberately non-fatal: report the failure and continue starting the app.
        print('Error..', str(e))

model_name = 'google/tapas-base-finetuned-wtq'

# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the model and tokenizer are reloaded each time; consider caching them
# (e.g. st.cache_resource, depending on the Streamlit version).
model = TapasForQuestionAnswering.from_pretrained(model_name, local_files_only=False)
tokenizer = TapasTokenizer.from_pretrained(model_name)

# NOTE(review): this option was removed in newer Streamlit releases, where
# setting it raises — confirm against the pinned Streamlit version.
st.set_option('deprecation.showfileUploaderEncoding', False)

st.title('Query your Table')
st.header('Upload CSV file')

uploaded_file = st.file_uploader("Choose your CSV file", type='csv')
# Reserved slot so the table can later be re-rendered with highlighted answers.
placeholder = st.empty()

if uploaded_file is not None:
    data = pd.read_csv(uploaded_file)
    # Remove every comma from string cells — presumably thousands separators,
    # so values like "1,234" can be parsed as numbers for SUM/AVERAGE.
    data.replace(',', '', regex=True, inplace=True)
    if st.checkbox('Want to see the data?'):
        placeholder.dataframe(data)

st.header('Enter your queries')
raw_queries = st.text_input('Type your queries separated by comma(,)', value='')
# Split on commas, trim surrounding whitespace and drop empty entries, so
# "q1, q2," becomes ["q1", "q2"] and an empty box yields no queries.
input_queries = [q.strip() for q in raw_queries.split(',') if q.strip()]

# One random hex colour per query: colors1 colours the query text, colors2 is
# the matching CSS used to highlight that query's answer cells in the table.
colors1 = ['#' + ''.join(random.choice('0123456789ABCDEF') for _ in range(6)) for _ in input_queries]
colors2 = ['background-color:' + color + '; color: black' for color in colors1]

def styling_specific_cell(x, tags, colors):
    """Build a frame of CSS strings shaped like *x*, for ``Styler.apply(axis=None)``.

    Args:
        x: The DataFrame being styled.
        tags: Sequence of groups; each group is an iterable of (row, column)
            integer positions of cells to paint.
        colors: CSS declarations; every cell in group ``i`` of *tags* gets
            ``colors[i]``. Later groups overwrite earlier ones on shared cells.

    Returns:
        A DataFrame with *x*'s index and columns holding '' for unstyled
        cells and the assigned CSS string otherwise.
    """
    css = pd.DataFrame('', index=x.index, columns=x.columns)
    for group_no, cells in enumerate(tags):
        for row_pos, col_pos in cells:
            # Looked up per cell, so an empty group never touches `colors`.
            css.iloc[row_pos, col_pos] = colors[group_no]
    return css
    
def _parse_numbers(values):
    """Return *values* converted to floats, or None if any value is non-numeric."""
    try:
        return [float(value) for value in values]
    except ValueError:
        return None


def _format_answer(cell_values, aggregation):
    """Render one query's answer text, applying the predicted aggregation.

    Args:
        cell_values: String contents of the table cells selected for the query.
        aggregation: One of "NONE", "SUM", "AVERAGE", "COUNT".

    Returns:
        For SUM/AVERAGE over numeric cells, the aggregated value; for COUNT,
        the number of selected cells; otherwise the cells joined by ", ".
    """
    if aggregation == 'COUNT':
        return str(len(cell_values))
    if aggregation in ('SUM', 'AVERAGE'):
        numbers = _parse_numbers(cell_values)
        # Fall back to the raw cells when nothing was selected or a cell is
        # not a number, instead of crashing the app.
        if numbers:
            if aggregation == 'SUM':
                return str(sum(numbers))
            return str(np.round(np.mean(numbers), 2))
    return ', '.join(cell_values)


if st.button('Predict Answers'):
    # Trim whitespace and ignore empty entries produced by stray commas.
    queries = [query.strip() for query in input_queries if query.strip()]
    if uploaded_file is None:
        # `data` only exists once a CSV has been uploaded.
        st.warning('Please upload a CSV file first.')
    elif not queries:
        st.warning('Please enter at least one query.')
    else:
        with st.spinner('It will take approx a minute'):
            # The TAPAS tokenizer is given an all-string table.
            data = data.astype(str)
            inputs = tokenizer(table=data, queries=queries, padding='max_length', return_tensors="pt")
            outputs = model(**inputs)
            predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
                inputs, outputs.logits.detach(), outputs.logits_aggregation.detach()
            )

            id2aggregation = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}
            aggregation_predictions_string = [id2aggregation[x] for x in predicted_aggregation_indices]

            # Each coordinate is a (row, column) position into the table.
            answers = [
                _format_answer([data.iat[coordinate] for coordinate in coordinates], predicted_agg)
                for coordinates, predicted_agg in zip(predicted_answer_coordinates, aggregation_predictions_string)
            ]

        st.success('Done! Please check below the answers and its cells highlighted in table above')

        # Re-render the table with each query's answer cells in that query's colour.
        placeholder.dataframe(
            data.style.apply(styling_specific_cell, tags=predicted_answer_coordinates, colors=colors2, axis=None)
        )

        for query, answer, c in zip(queries, answers, colors1):
            st.write('\n')
            # The query is user text embedded in raw HTML: escape it so inputs
            # like "price < 100" render literally.
            st.markdown('<font color={} size=4>**{}**</font>'.format(c, html.escape(query)), unsafe_allow_html=True)
            st.write('\n')
            st.markdown('**>** ' + answer)