davanstrien HF staff commited on
Commit
13dd954
·
1 Parent(s): 6b04271
Files changed (3) hide show
  1. app.py +108 -0
  2. requirements.in +4 -0
  3. requirements.txt +294 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from qdrant_client import QdrantClient
3
+ from qdrant_client import models
4
+ from sentence_transformers import SentenceTransformer
5
+ from huggingface_hub import hf_hub_url
6
+ from dotenv import load_dotenv
7
+ import os
8
+ from functools import lru_cache
9
+
10
+ load_dotenv()
11
+
12
+ URL = os.getenv("QDRANT_URL")
13
+ QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
14
+ sentence_embedding_model = SentenceTransformer("BAAI/bge-large-en")
15
+
16
+ print(URL)
17
+ print(QDRANT_API_KEY)
18
+ collection_name = "dataset_cards"
19
+ client = QdrantClient(
20
+ url=URL,
21
+ api_key=QDRANT_API_KEY,
22
+ )
23
+
24
+
25
+ def format_results(results):
26
+ markdown = ""
27
+ for result in results:
28
+ hub_id = result.payload["id"]
29
+ url = hf_hub_url(hub_id, "README.md", repo_type="dataset")
30
+ header = f"## [{hub_id}]({url})"
31
+ markdown += header + "\n"
32
+ markdown += result.payload["section_text"] + "\n"
33
+ return markdown
34
+
35
+
36
+ @lru_cache()
37
+ def search(query: str):
38
+ query_ = sentence_embedding_model.encode(
39
+ f"Represent this sentence for searching relevant passages:{query}"
40
+ )
41
+ results = client.search(
42
+ collection_name="dataset_cards",
43
+ query_vector=query_,
44
+ limit=10,
45
+ )
46
+ return format_results(results)
47
+
48
+
49
+ @lru_cache()
50
+ def hub_id_qdrant_id(hub_id):
51
+ matches = client.scroll(
52
+ collection_name="dataset_cards",
53
+ scroll_filter=models.Filter(
54
+ must=[
55
+ models.FieldCondition(key="id", match=models.MatchValue(value=hub_id)),
56
+ ]
57
+ ),
58
+ limit=1,
59
+ with_payload=True,
60
+ with_vectors=False,
61
+ )
62
+ try:
63
+ return matches[0][0].id
64
+ except IndexError as e:
65
+ raise gr.Error(
66
+ f"Hub id {hub_id} not in out database. This could be because it is very new or because it doesn't have much documentation."
67
+ ) from e
68
+
69
+
70
+ @lru_cache()
71
+ def recommend(hub_id):
72
+ positive_id = hub_id_qdrant_id(hub_id)
73
+ results = client.recommend(collection_name=collection_name, positive=[positive_id])
74
+ return format_results(results)
75
+
76
+
77
+ def query(search_term, search_type):
78
+ if search_type == "Recommend similar datasets":
79
+ return recommend(search_term)
80
+ else:
81
+ return search(search_term)
82
+
83
+
84
+ with gr.Blocks() as demo:
85
+ gr.Markdown("## 🤗 Sematic dataset search")
86
+ with gr.Row():
87
+ gr.Markdown(
88
+ "This Gradio app allows you to search for datasets based on their descriptions. You can either search for similar datasets to a given dataset or search for datasets based on a query."
89
+ )
90
+ with gr.Row():
91
+ search_term = gr.Textbox(value="movie review sentiment",
92
+ label="hub id i.e. IMDB or query i.e. movie review sentiment"
93
+ )
94
+ with gr.Row():
95
+ with gr.Row():
96
+ find_similar_btn = gr.Button("Search")
97
+ search_type = gr.Radio(
98
+ ["Recommend similar datasets", "Semantic Search"],
99
+ label="Search type",
100
+ value="Semantic Search",
101
+ interactive=True,
102
+ )
103
+
104
+ results = gr.Markdown()
105
+ find_similar_btn.click(query, [search_term, search_type], results)
106
+
107
+
108
+ demo.launch()
requirements.in ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ python-dotenv
3
+ qdrant-client==1.3.1
4
+ sentence-transformers
requirements.txt ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is autogenerated by pip-compile with Python 3.11
3
+ # by the following command:
4
+ #
5
+ # pip-compile
6
+ #
7
+ aiofiles==23.2.1
8
+ # via gradio
9
+ aiohttp==3.8.5
10
+ # via gradio
11
+ aiosignal==1.3.1
12
+ # via aiohttp
13
+ altair==5.0.1
14
+ # via gradio
15
+ anyio==3.7.1
16
+ # via
17
+ # httpcore
18
+ # starlette
19
+ async-timeout==4.0.2
20
+ # via aiohttp
21
+ attrs==23.1.0
22
+ # via
23
+ # aiohttp
24
+ # jsonschema
25
+ # referencing
26
+ certifi==2023.7.22
27
+ # via
28
+ # httpcore
29
+ # httpx
30
+ # requests
31
+ charset-normalizer==3.2.0
32
+ # via
33
+ # aiohttp
34
+ # requests
35
+ click==8.1.6
36
+ # via
37
+ # nltk
38
+ # uvicorn
39
+ contourpy==1.1.0
40
+ # via matplotlib
41
+ cycler==0.11.0
42
+ # via matplotlib
43
+ fastapi==0.101.0
44
+ # via gradio
45
+ ffmpy==0.3.1
46
+ # via gradio
47
+ filelock==3.12.2
48
+ # via
49
+ # huggingface-hub
50
+ # torch
51
+ # transformers
52
+ fonttools==4.42.0
53
+ # via matplotlib
54
+ frozenlist==1.4.0
55
+ # via
56
+ # aiohttp
57
+ # aiosignal
58
+ fsspec==2023.6.0
59
+ # via
60
+ # gradio-client
61
+ # huggingface-hub
62
+ gradio==3.39.0
63
+ # via -r requirements.in
64
+ gradio-client==0.3.0
65
+ # via gradio
66
+ grpcio==1.56.2
67
+ # via
68
+ # grpcio-tools
69
+ # qdrant-client
70
+ grpcio-tools==1.56.2
71
+ # via qdrant-client
72
+ h11==0.14.0
73
+ # via
74
+ # httpcore
75
+ # uvicorn
76
+ h2==4.1.0
77
+ # via httpx
78
+ hpack==4.0.0
79
+ # via h2
80
+ httpcore==0.17.3
81
+ # via httpx
82
+ httpx[http2]==0.24.1
83
+ # via
84
+ # gradio
85
+ # gradio-client
86
+ # qdrant-client
87
+ huggingface-hub==0.16.4
88
+ # via
89
+ # gradio
90
+ # gradio-client
91
+ # sentence-transformers
92
+ # transformers
93
+ hyperframe==6.0.1
94
+ # via h2
95
+ idna==3.4
96
+ # via
97
+ # anyio
98
+ # httpx
99
+ # requests
100
+ # yarl
101
+ jinja2==3.1.2
102
+ # via
103
+ # altair
104
+ # gradio
105
+ # torch
106
+ joblib==1.3.2
107
+ # via
108
+ # nltk
109
+ # scikit-learn
110
+ jsonschema==4.19.0
111
+ # via altair
112
+ jsonschema-specifications==2023.7.1
113
+ # via jsonschema
114
+ kiwisolver==1.4.4
115
+ # via matplotlib
116
+ linkify-it-py==2.0.2
117
+ # via markdown-it-py
118
+ markdown-it-py[linkify]==2.2.0
119
+ # via
120
+ # gradio
121
+ # mdit-py-plugins
122
+ markupsafe==2.1.3
123
+ # via
124
+ # gradio
125
+ # jinja2
126
+ matplotlib==3.7.2
127
+ # via gradio
128
+ mdit-py-plugins==0.3.3
129
+ # via gradio
130
+ mdurl==0.1.2
131
+ # via markdown-it-py
132
+ mpmath==1.3.0
133
+ # via sympy
134
+ multidict==6.0.4
135
+ # via
136
+ # aiohttp
137
+ # yarl
138
+ networkx==3.1
139
+ # via torch
140
+ nltk==3.8.1
141
+ # via sentence-transformers
142
+ numpy==1.25.2
143
+ # via
144
+ # altair
145
+ # contourpy
146
+ # gradio
147
+ # matplotlib
148
+ # pandas
149
+ # qdrant-client
150
+ # scikit-learn
151
+ # scipy
152
+ # sentence-transformers
153
+ # torchvision
154
+ # transformers
155
+ orjson==3.9.4
156
+ # via gradio
157
+ packaging==23.1
158
+ # via
159
+ # gradio
160
+ # gradio-client
161
+ # huggingface-hub
162
+ # matplotlib
163
+ # transformers
164
+ pandas==2.0.3
165
+ # via
166
+ # altair
167
+ # gradio
168
+ pillow==10.0.0
169
+ # via
170
+ # gradio
171
+ # matplotlib
172
+ # torchvision
173
+ portalocker==2.7.0
174
+ # via qdrant-client
175
+ protobuf==4.24.0
176
+ # via grpcio-tools
177
+ pydantic==1.10.12
178
+ # via
179
+ # fastapi
180
+ # gradio
181
+ # qdrant-client
182
+ pydub==0.25.1
183
+ # via gradio
184
+ pyparsing==3.0.9
185
+ # via matplotlib
186
+ python-dateutil==2.8.2
187
+ # via
188
+ # matplotlib
189
+ # pandas
190
+ python-dotenv==1.0.0
191
+ # via -r requirements.in
192
+ python-multipart==0.0.6
193
+ # via gradio
194
+ pytz==2023.3
195
+ # via pandas
196
+ pyyaml==6.0.1
197
+ # via
198
+ # gradio
199
+ # huggingface-hub
200
+ # transformers
201
+ qdrant-client==1.3.1
202
+ # via -r requirements.in
203
+ referencing==0.30.2
204
+ # via
205
+ # jsonschema
206
+ # jsonschema-specifications
207
+ regex==2023.8.8
208
+ # via
209
+ # nltk
210
+ # transformers
211
+ requests==2.31.0
212
+ # via
213
+ # gradio
214
+ # gradio-client
215
+ # huggingface-hub
216
+ # torchvision
217
+ # transformers
218
+ rpds-py==0.9.2
219
+ # via
220
+ # jsonschema
221
+ # referencing
222
+ safetensors==0.3.2
223
+ # via transformers
224
+ scikit-learn==1.3.0
225
+ # via sentence-transformers
226
+ scipy==1.11.1
227
+ # via
228
+ # scikit-learn
229
+ # sentence-transformers
230
+ semantic-version==2.10.0
231
+ # via gradio
232
+ sentence-transformers==2.2.2
233
+ # via -r requirements.in
234
+ sentencepiece==0.1.99
235
+ # via sentence-transformers
236
+ six==1.16.0
237
+ # via python-dateutil
238
+ sniffio==1.3.0
239
+ # via
240
+ # anyio
241
+ # httpcore
242
+ # httpx
243
+ starlette==0.27.0
244
+ # via fastapi
245
+ sympy==1.12
246
+ # via torch
247
+ threadpoolctl==3.2.0
248
+ # via scikit-learn
249
+ tokenizers==0.13.3
250
+ # via transformers
251
+ toolz==0.12.0
252
+ # via altair
253
+ torch==2.0.1
254
+ # via
255
+ # sentence-transformers
256
+ # torchvision
257
+ torchvision==0.15.2
258
+ # via sentence-transformers
259
+ tqdm==4.66.0
260
+ # via
261
+ # huggingface-hub
262
+ # nltk
263
+ # sentence-transformers
264
+ # transformers
265
+ transformers==4.31.0
266
+ # via sentence-transformers
267
+ typing-extensions==4.5.0
268
+ # via
269
+ # fastapi
270
+ # gradio
271
+ # gradio-client
272
+ # huggingface-hub
273
+ # pydantic
274
+ # qdrant-client
275
+ # torch
276
+ tzdata==2023.3
277
+ # via pandas
278
+ uc-micro-py==1.0.2
279
+ # via linkify-it-py
280
+ urllib3==1.26.16
281
+ # via
282
+ # qdrant-client
283
+ # requests
284
+ uvicorn==0.23.2
285
+ # via gradio
286
+ websockets==11.0.3
287
+ # via
288
+ # gradio
289
+ # gradio-client
290
+ yarl==1.9.2
291
+ # via aiohttp
292
+
293
+ # The following packages are considered to be unsafe in a requirements file:
294
+ # setuptools