Fangrui Liu commited on
Commit
725da8c
·
1 Parent(s): c8fbf76

add features and datasets

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +115 -43
  3. requirements.txt +2 -1
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .streamlit
app.py CHANGED
@@ -7,15 +7,22 @@ from transformers import CLIPTokenizerFast, AutoTokenizer
7
  import torch
8
  import logging
9
  from os import environ
 
10
  environ['TOKENIZERS_PARALLELISM'] = 'true'
11
 
12
- from myscaledb import Client
 
 
 
 
13
 
14
  DB_NAME = "mqdb_demo.unsplash_25k_clip_indexer"
15
  MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
16
  DIMS = 512
17
  # Ignore some bad links (broken in the dataset already)
18
- BAD_IDS = {'9_9hzZVjV8s', 'RDs0THr4lGs', 'vigsqYux_-8', 'rsJtMXn3p_c', 'AcG-unN00gw', 'r1R_0ZNUcx0'}
 
 
19
 
20
  @st.experimental_singleton(show_spinner=False)
21
  def init_clip():
@@ -28,6 +35,7 @@ def init_clip():
28
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
29
  return tokenizer, clip
30
 
 
31
  @st.experimental_singleton(show_spinner=False)
32
  def init_db():
33
  """ Initialize the Database Connection
@@ -36,17 +44,20 @@ def init_db():
36
  meta_field: Meta field that records if an image is viewed or not
37
  client: Database connection object
38
  """
39
- client = Client(url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
 
40
  # We can check if the connection is alive
41
  assert client.is_alive()
42
  meta_field = {}
43
  return meta_field, client
44
 
 
45
  @st.experimental_singleton(show_spinner=False)
46
  def init_query_num():
47
  print("init query_num")
48
  return 0
49
 
 
50
  def query(xq, top_k=10):
51
  """ Query TopK matched w.r.t a given vector
52
 
@@ -62,30 +73,29 @@ def query(xq, top_k=10):
62
  while attempt < 3:
63
  try:
64
  xq_s = f"[{', '.join([str(float(fnum)) for fnum in list(xq)])}]"
65
-
66
  print('Excluded pre:', st.session_state.meta)
67
  if len(st.session_state.meta) > 0:
68
- exclude_list = ','.join([f'\'{i}\'' for i, v in st.session_state.meta.items() if v >= 1])
 
69
  print("Excluded:", exclude_list)
70
  # Using PREWHERE allows you to do column filter before vector search
71
  xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
72
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
73
- FROM {DB_NAME} PREWHERE id NOT IN ({exclude_list})")
 
74
  else:
75
  xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
76
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
77
- FROM {DB_NAME}")
78
- # real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
79
- # 1 - arraySum(arrayMap((x, y) -> x * y, {xq_s}, vector)) AS dist\
80
- # FROM {DB_NAME} ORDER BY dist LIMIT {top_k}")
81
- # FIXME: This is causing freezing on DB
82
  real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
83
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
84
- FROM {DB_NAME}")
85
  top_k = real_xc
86
- xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or \
87
- st.session_state.meta[xi['id']] < 1]
88
- logging.info(f'{len(xc)} records returned, {[_i["id"] for _i in xc]}')
 
89
  matches = xc
90
  break
91
  except Exception as e:
@@ -98,20 +108,23 @@ def query(xq, top_k=10):
98
  logging.error(f"No matches found for '{DB_NAME}'")
99
  return matches, top_k
100
 
 
101
  @st.experimental_singleton(show_spinner=False)
102
  def init_random_query():
103
  xq = np.random.rand(DIMS).tolist()
104
  return xq, xq.copy()
105
 
 
106
  class Classifier:
107
  """ Zero-shot Classifier
108
  This Classifier provides proxy regarding to the user's reaction to the probed images.
109
  The proxy will replace the original query vector generated by prompted vector and finally
110
  give the user a satisfying retrieval result.
111
-
112
  This can be commonly seen in a recommendation system. The classifier will recommend more
113
  precise result as it accumulating user's activity.
114
  """
 
115
  def __init__(self, xq: list):
116
  # initialize model with DIMS input size and 1 output
117
  # note that the bias is ignored, as we only focus on the inner product result
@@ -122,7 +135,7 @@ class Classifier:
122
  # init loss and optimizer
123
  self.loss = torch.nn.BCEWithLogitsLoss()
124
  self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
125
-
126
  def fit(self, X: list, y: list, iters: int = 5):
127
  # convert X and y to tensor
128
  X = torch.Tensor(X)
@@ -132,7 +145,8 @@ class Classifier:
132
  self.optimizer.zero_grad()
133
  # Normalize the weight before inference
134
  # This will constrain the gradient or you will have an explosion on query vector
135
- self.model.weight.data = self.model.weight.data / torch.norm(self.model.weight.data, p=2, dim=-1)
 
136
  # forward pass
137
  out = self.model(X)
138
  # compute loss
@@ -141,11 +155,17 @@ class Classifier:
141
  loss.backward()
142
  # update weights
143
  self.optimizer.step()
144
-
145
  def get_weights(self):
146
  xq = self.model.weight.detach().numpy()[0].tolist()
147
  return xq
148
 
 
 
 
 
 
 
149
  def prompt2vec(prompt: str):
150
  """ Convert prompt into a computational vector
151
 
@@ -161,6 +181,7 @@ def prompt2vec(prompt: str):
161
  xq = out.squeeze(0).cpu().detach().numpy().tolist()
162
  return xq
163
 
 
164
  def pil_to_bytes(img):
165
  """ Convert a Pillow image into base64
166
 
@@ -176,13 +197,16 @@ def pil_to_bytes(img):
176
  img_bin = base64.b64encode(img_bin).decode('utf-8')
177
  return img_bin
178
 
 
179
  def card(i, url):
180
  return f'<img id="img{i}" src="{url}" width="200px;">'
181
 
 
182
  def card_with_conf(i, conf, url):
183
- conf = "%.4f"%(conf)
184
  return f'<img id="img{i}" src="{url}" width="200px;" style="margin:50px 50px"><div><p><b>Relevance: {conf}</b></p></div>'
185
 
 
186
  def get_top_k(xq, top_k=9):
187
  """ wrapper function for query
188
 
@@ -198,6 +222,7 @@ def get_top_k(xq, top_k=9):
198
  )
199
  return matches
200
 
 
201
  def tune(X, y, iters=2):
202
  """ Train the Zero-shot Classifier
203
 
@@ -206,6 +231,7 @@ def tune(X, y, iters=2):
206
  y (list of floats or numpy.ndarray): Scores given by user
207
  iters (int, optional): iterations of updates to be run
208
  """
 
209
  # train the classifier
210
  st.session_state.clf.fit(X, y, iters=iters)
211
  # extract new vector
@@ -224,17 +250,19 @@ def refresh_index():
224
  st.session_state.meta, st.session_state.index = init_db()
225
  del st.session_state.clf, st.session_state.xq
226
 
 
227
  def calc_dist():
228
  xq = np.array(st.session_state.xq)
229
  orig_xq = np.array(st.session_state.orig_xq)
230
  return np.linalg.norm(xq - orig_xq)
231
 
 
232
  def submit():
233
  """ Tune the model w.r.t given score from user.
234
  """
235
  st.session_state.query_num += 1
236
  matches = st.session_state.matches
237
- velocity = 1 #st.session_state.velocity
238
  scores = {}
239
  states = [
240
  st.session_state[f"input{i}"] for i in range(len(matches))
@@ -253,9 +281,11 @@ def submit():
253
  st.session_state.meta[match['id']] = 1
254
  logging.info(f"Exclude List: {st.session_state.meta}")
255
 
 
256
  def delete_element(element):
257
  del element
258
 
 
259
  st.markdown("""
260
  <link
261
  rel="stylesheet"
@@ -308,32 +338,56 @@ if 'xq' not in st.session_state:
308
  msg = messages[st.session_state.query_num]
309
  else:
310
  msg = messages[-1]
311
-
312
  # Basic Layout
313
-
314
  with st.container():
 
 
315
  st.title("Visual Dataset Explorer")
316
- start = [st.empty(), st.empty(), st.empty(), st.empty(), st.empty()]
 
317
  start[0].info(msg)
318
- prompt = start[1].text_input("Prompt:", value="", placeholder="Examples: white dogs, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...")
319
- start[2].markdown(
 
 
 
 
 
320
  '<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>\
321
  <p>🌟 We also support multi-language search. Type any language you know to search! ⌨️ </p>',
322
  unsafe_allow_html=True)
323
- with start[3]:
 
 
 
 
324
  col = st.columns(8)
 
325
  prompt_xq = col[6].button("Prompt", disabled=len(prompt) == 0)
326
- random_xq = col[7].button("Random", disabled=len(prompt) != 0)
 
 
327
  if random_xq:
328
  # Randomly pick a vector to query
329
  xq, orig_xq = init_random_query()
330
  st.session_state.xq = xq
331
  st.session_state.orig_xq = orig_xq
332
  _ = [elem.empty() for elem in start]
333
- elif prompt_xq:
334
- print(f"Input prompt is {prompt}")
335
- # Tokenize the vectors
336
- xq = prompt2vec(prompt)
 
 
 
 
 
 
 
 
 
 
337
  st.session_state.xq = xq
338
  st.session_state.orig_xq = xq
339
  _ = [elem.empty() for elem in start]
@@ -347,11 +401,21 @@ if 'xq' in st.session_state:
347
  # initialize classifier
348
  if 'clf' not in st.session_state:
349
  st.session_state.clf = Classifier(st.session_state.xq)
350
-
351
  # if we want to display images we end up here
352
  st.info(msg)
353
  # first retrieve images from pinecone
354
- st.session_state.matches, st.session_state.top_k = get_top_k(st.session_state.clf.get_weights(), top_k=9)
 
 
 
 
 
 
 
 
 
 
355
  with st.container():
356
  with st.sidebar:
357
  with st.container():
@@ -364,15 +428,23 @@ if 'xq' in st.session_state:
364
  else:
365
  disable = True
366
  dist = np.matmul(st.session_state.clf.get_weights() / np.linalg.norm(st.session_state.clf.get_weights()),
367
- np.array(k["vector"]).T)
368
- st.markdown(card_with_conf(i, dist, url), unsafe_allow_html=True)
369
-
 
 
 
 
 
 
370
  # once retrieved, display them alongside checkboxes in a form
371
  with st.form("batch", clear_on_submit=False):
372
- st.session_state.iters = st.slider("Number of Iterations to Update", min_value=0, max_value=10, step=1, value=2)
373
- col = st.columns([1,9])
 
374
  col[0].form_submit_button("Train!", on_click=submit)
375
- col[1].form_submit_button("Choose a new prompt", on_click=refresh_index)
 
376
  # we have three columns in the form
377
  cols = st.columns(3)
378
  for i, match in enumerate(st.session_state.matches):
@@ -384,9 +456,9 @@ if 'xq' in st.session_state:
384
  else:
385
  disable = True
386
  # the card shows an image and a checkbox
387
- cols[i%3].markdown(card(i, url), unsafe_allow_html=True)
388
  # we access the values of the checkbox via st.session_state[f"input{i}"]
389
- cols[i%3].slider(
390
  "Relevance",
391
  min_value=0.0,
392
  max_value=1.0,
@@ -394,4 +466,4 @@ if 'xq' in st.session_state:
394
  step=0.05,
395
  key=f"input{i}",
396
  disabled=disabled
397
- )
 
7
  import torch
8
  import logging
9
  from os import environ
10
+ from myscaledb import Client
11
  environ['TOKENIZERS_PARALLELISM'] = 'true'
12
 
13
+
14
+ db_name_map = {
15
+ "Unsplash Photos 25K": "mqdb_demo.unsplash_25k_clip_indexer",
16
+ "RSICD: Remote Sensing Images 11K": "mqdb_demo.rsicd_clip_b_32",
17
+ }
18
 
19
  DB_NAME = "mqdb_demo.unsplash_25k_clip_indexer"
20
  MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
21
  DIMS = 512
22
  # Ignore some bad links (broken in the dataset already)
23
+ BAD_IDS = {'9_9hzZVjV8s', 'RDs0THr4lGs', 'vigsqYux_-8',
24
+ 'rsJtMXn3p_c', 'AcG-unN00gw', 'r1R_0ZNUcx0'}
25
+
26
 
27
  @st.experimental_singleton(show_spinner=False)
28
  def init_clip():
 
35
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
36
  return tokenizer, clip
37
 
38
+
39
  @st.experimental_singleton(show_spinner=False)
40
  def init_db():
41
  """ Initialize the Database Connection
 
44
  meta_field: Meta field that records if an image is viewed or not
45
  client: Database connection object
46
  """
47
+ client = Client(
48
+ url=st.secrets["DB_URL"], user=st.secrets["USER"], password=st.secrets["PASSWD"])
49
  # We can check if the connection is alive
50
  assert client.is_alive()
51
  meta_field = {}
52
  return meta_field, client
53
 
54
+
55
  @st.experimental_singleton(show_spinner=False)
56
  def init_query_num():
57
  print("init query_num")
58
  return 0
59
 
60
+
61
  def query(xq, top_k=10):
62
  """ Query TopK matched w.r.t a given vector
63
 
 
73
  while attempt < 3:
74
  try:
75
  xq_s = f"[{', '.join([str(float(fnum)) for fnum in list(xq)])}]"
76
+
77
  print('Excluded pre:', st.session_state.meta)
78
  if len(st.session_state.meta) > 0:
79
+ exclude_list = ','.join(
80
+ [f'\'{i}\'' for i, v in st.session_state.meta.items() if v >= 1])
81
  print("Excluded:", exclude_list)
82
  # Using PREWHERE allows you to do column filter before vector search
83
  xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
84
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
85
+ FROM {db_name_map[st.session_state.db_name_ref]} \
86
+ PREWHERE id NOT IN ({exclude_list})")
87
  else:
88
  xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
89
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
90
+ FROM {db_name_map[st.session_state.db_name_ref]}")
 
 
 
 
91
  real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
92
  distance('topK={top_k}')(vector, {xq_s}) AS dist\
93
+ FROM {db_name_map[st.session_state.db_name_ref]}")
94
  top_k = real_xc
95
+ xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or
96
+ st.session_state.meta[xi['id']] < 1]
97
+ logging.info(
98
+ f'{len(xc)} records returned, {[_i["id"] for _i in xc]}')
99
  matches = xc
100
  break
101
  except Exception as e:
 
108
  logging.error(f"No matches found for '{DB_NAME}'")
109
  return matches, top_k
110
 
111
+
112
  @st.experimental_singleton(show_spinner=False)
113
  def init_random_query():
114
  xq = np.random.rand(DIMS).tolist()
115
  return xq, xq.copy()
116
 
117
+
118
  class Classifier:
119
  """ Zero-shot Classifier
120
  This Classifier provides proxy regarding to the user's reaction to the probed images.
121
  The proxy will replace the original query vector generated by prompted vector and finally
122
  give the user a satisfying retrieval result.
123
+
124
  This can be commonly seen in a recommendation system. The classifier will recommend more
125
  precise result as it accumulating user's activity.
126
  """
127
+
128
  def __init__(self, xq: list):
129
  # initialize model with DIMS input size and 1 output
130
  # note that the bias is ignored, as we only focus on the inner product result
 
135
  # init loss and optimizer
136
  self.loss = torch.nn.BCEWithLogitsLoss()
137
  self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
138
+
139
  def fit(self, X: list, y: list, iters: int = 5):
140
  # convert X and y to tensor
141
  X = torch.Tensor(X)
 
145
  self.optimizer.zero_grad()
146
  # Normalize the weight before inference
147
  # This will constrain the gradient or you will have an explosion on query vector
148
+ self.model.weight.data = self.model.weight.data / \
149
+ torch.norm(self.model.weight.data, p=2, dim=-1)
150
  # forward pass
151
  out = self.model(X)
152
  # compute loss
 
155
  loss.backward()
156
  # update weights
157
  self.optimizer.step()
158
+
159
  def get_weights(self):
160
  xq = self.model.weight.detach().numpy()[0].tolist()
161
  return xq
162
 
163
+
164
+ class NormalizingLayer(torch.nn.Module):
165
+ def forward(self, x):
166
+ return x / torch.norm(x, dim=-1, keepdim=True)
167
+
168
+
169
  def prompt2vec(prompt: str):
170
  """ Convert prompt into a computational vector
171
 
 
181
  xq = out.squeeze(0).cpu().detach().numpy().tolist()
182
  return xq
183
 
184
+
185
  def pil_to_bytes(img):
186
  """ Convert a Pillow image into base64
187
 
 
197
  img_bin = base64.b64encode(img_bin).decode('utf-8')
198
  return img_bin
199
 
200
+
201
  def card(i, url):
202
  return f'<img id="img{i}" src="{url}" width="200px;">'
203
 
204
+
205
  def card_with_conf(i, conf, url):
206
+ conf = "%.4f" % (conf)
207
  return f'<img id="img{i}" src="{url}" width="200px;" style="margin:50px 50px"><div><p><b>Relevance: {conf}</b></p></div>'
208
 
209
+
210
  def get_top_k(xq, top_k=9):
211
  """ wrapper function for query
212
 
 
222
  )
223
  return matches
224
 
225
+
226
  def tune(X, y, iters=2):
227
  """ Train the Zero-shot Classifier
228
 
 
231
  y (list of floats or numpy.ndarray): Scores given by user
232
  iters (int, optional): iterations of updates to be run
233
  """
234
+ assert len(X) == len(y)
235
  # train the classifier
236
  st.session_state.clf.fit(X, y, iters=iters)
237
  # extract new vector
 
250
  st.session_state.meta, st.session_state.index = init_db()
251
  del st.session_state.clf, st.session_state.xq
252
 
253
+
254
  def calc_dist():
255
  xq = np.array(st.session_state.xq)
256
  orig_xq = np.array(st.session_state.orig_xq)
257
  return np.linalg.norm(xq - orig_xq)
258
 
259
+
260
  def submit():
261
  """ Tune the model w.r.t given score from user.
262
  """
263
  st.session_state.query_num += 1
264
  matches = st.session_state.matches
265
+ velocity = 1 # st.session_state.velocity
266
  scores = {}
267
  states = [
268
  st.session_state[f"input{i}"] for i in range(len(matches))
 
281
  st.session_state.meta[match['id']] = 1
282
  logging.info(f"Exclude List: {st.session_state.meta}")
283
 
284
+
285
  def delete_element(element):
286
  del element
287
 
288
+
289
  st.markdown("""
290
  <link
291
  rel="stylesheet"
 
338
  msg = messages[st.session_state.query_num]
339
  else:
340
  msg = messages[-1]
341
+ prompt = ''
342
  # Basic Layout
 
343
  with st.container():
344
+ if 'prompt' in st.session_state:
345
+ del st.session_state.prompt
346
  st.title("Visual Dataset Explorer")
347
+ start = [st.empty(), st.empty(), st.empty(), st.empty(),
348
+ st.empty(), st.empty(), st.empty()]
349
  start[0].info(msg)
350
+ st.session_state.db_name_ref = start[1].selectbox(
351
+ "Select Database:", list(db_name_map.keys()))
352
+ prompt = start[2].text_input(
353
+ "Prompt:", value="", placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...")
354
+ if len(prompt) > 0:
355
+ st.session_state.prompt = prompt
356
+ start[3].markdown(
357
  '<p style="color:gray;"> Don\'t know what to search? Try <b>Random</b>!</p>\
358
  <p>🌟 We also support multi-language search. Type any language you know to search! ⌨️ </p>',
359
  unsafe_allow_html=True)
360
+ upld_model = start[5].file_uploader(
361
+ "Or you can upload your previous run!", type='onnx')
362
+ upld_btn = start[6].button(
363
+ "Used Loaded Weights", disabled=upld_model is None)
364
+ with start[4]:
365
  col = st.columns(8)
366
+ has_no_prompt = (len(prompt) == 0 and upld_model is None)
367
  prompt_xq = col[6].button("Prompt", disabled=len(prompt) == 0)
368
+ random_xq = col[7].button("Random", disabled=not (
369
+ len(prompt) == 0 and upld_model is None))
370
+
371
  if random_xq:
372
  # Randomly pick a vector to query
373
  xq, orig_xq = init_random_query()
374
  st.session_state.xq = xq
375
  st.session_state.orig_xq = orig_xq
376
  _ = [elem.empty() for elem in start]
377
+ elif prompt_xq or upld_btn:
378
+ if upld_model is not None:
379
+ # Import vector from a file
380
+ import onnx
381
+ from onnx import numpy_helper
382
+ _model = onnx.load(upld_model)
383
+ weights = _model.graph.initializer
384
+ assert len(weights) == 1
385
+ xq = numpy_helper.to_array(weights[0]).tolist()
386
+ assert len(xq) == DIMS
387
+ else:
388
+ print(f"Input prompt is {prompt}")
389
+ # Tokenize the vectors
390
+ xq = prompt2vec(prompt)
391
  st.session_state.xq = xq
392
  st.session_state.orig_xq = xq
393
  _ = [elem.empty() for elem in start]
 
401
  # initialize classifier
402
  if 'clf' not in st.session_state:
403
  st.session_state.clf = Classifier(st.session_state.xq)
404
+
405
  # if we want to display images we end up here
406
  st.info(msg)
407
  # first retrieve images from pinecone
408
+ st.session_state.matches, st.session_state.top_k = get_top_k(
409
+ st.session_state.clf.get_weights(), top_k=9)
410
+
411
+ # export the model into executable ONNX
412
+ st.session_state.dnld_model = BytesIO()
413
+ torch.onnx.export(torch.nn.Sequential(NormalizingLayer(), st.session_state.clf.model),
414
+ torch.as_tensor(st.session_state.xq).reshape(1, -1),
415
+ st.session_state.dnld_model,
416
+ input_names=['input'],
417
+ output_names=['output'])
418
+
419
  with st.container():
420
  with st.sidebar:
421
  with st.container():
 
428
  else:
429
  disable = True
430
  dist = np.matmul(st.session_state.clf.get_weights() / np.linalg.norm(st.session_state.clf.get_weights()),
431
+ np.array(k["vector"]).T)
432
+ st.markdown(card_with_conf(i, dist, url),
433
+ unsafe_allow_html=True)
434
+ dnld_nam = st.text_input('Download Name:',
435
+ f'{(st.session_state.prompt if "prompt" in st.session_state else (upld_model.name.split(".onnx")[0] if upld_model is not None else "model"))}.onnx',
436
+ max_chars=50)
437
+ dnld_btn = st.download_button('Download your classifier!',
438
+ st.session_state.dnld_model,
439
+ dnld_nam,)
440
  # once retrieved, display them alongside checkboxes in a form
441
  with st.form("batch", clear_on_submit=False):
442
+ st.session_state.iters = st.slider(
443
+ "Number of Iterations to Update", min_value=0, max_value=10, step=1, value=2)
444
+ col = st.columns([1, 9])
445
  col[0].form_submit_button("Train!", on_click=submit)
446
+ col[1].form_submit_button(
447
+ "Choose a new prompt", on_click=refresh_index)
448
  # we have three columns in the form
449
  cols = st.columns(3)
450
  for i, match in enumerate(st.session_state.matches):
 
456
  else:
457
  disable = True
458
  # the card shows an image and a checkbox
459
+ cols[i % 3].markdown(card(i, url), unsafe_allow_html=True)
460
  # we access the values of the checkbox via st.session_state[f"input{i}"]
461
+ cols[i % 3].slider(
462
  "Relevance",
463
  min_value=0.0,
464
  max_value=1.0,
 
466
  step=0.05,
467
  key=f"input{i}",
468
  disabled=disabled
469
+ )
requirements.txt CHANGED
@@ -4,4 +4,5 @@ myscaledb-client
4
  streamlit
5
  multilingual-clip
6
  numpy
7
- torch
 
 
4
  streamlit
5
  multilingual-clip
6
  numpy
7
+ torch
8
+ onnx