davanstrien HF staff commited on
Commit
8fbb0f5
·
1 Parent(s): f7b0e5b

chore: Add main.py for dataset search functionality

Browse files
Files changed (1) hide show
  1. main.py +180 -0
main.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sqlite3
3
+ from contextlib import asynccontextmanager
4
+ from fastapi import FastAPI, Query, HTTPException
5
+ from typing import List, Optional
6
+ from pydantic import BaseModel
7
+ from data_loader import refresh_data
8
+ import numpy as np
9
+ from pandas import Timestamp
10
+
11
+
12
+ def get_db_connection():
13
+ conn = sqlite3.connect("datasets.db")
14
+ conn.row_factory = sqlite3.Row
15
+ return conn
16
+
17
+
18
+ def setup_database():
19
+ conn = get_db_connection()
20
+ c = conn.cursor()
21
+ c.execute("""CREATE TABLE IF NOT EXISTS datasets
22
+ (hub_id TEXT PRIMARY KEY,
23
+ likes INTEGER,
24
+ downloads INTEGER,
25
+ tags TEXT,
26
+ created_at INTEGER,
27
+ last_modified INTEGER,
28
+ license TEXT,
29
+ language TEXT,
30
+ config_name TEXT,
31
+ column_names TEXT,
32
+ features TEXT)""")
33
+ c.execute("CREATE INDEX IF NOT EXISTS idx_column_names ON datasets (column_names)")
34
+ conn.commit()
35
+ conn.close()
36
+
37
+
38
+ def serialize_numpy(obj):
39
+ if isinstance(obj, np.ndarray):
40
+ return obj.tolist()
41
+ if isinstance(obj, np.integer):
42
+ return int(obj)
43
+ if isinstance(obj, np.floating):
44
+ return float(obj)
45
+ if isinstance(obj, Timestamp):
46
+ return int(obj.timestamp())
47
+ raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
48
+
49
+
50
+ def insert_data(conn, data):
51
+ c = conn.cursor()
52
+
53
+ created_at = data.get("created_at", 0)
54
+ if isinstance(created_at, Timestamp):
55
+ created_at = int(created_at.timestamp())
56
+
57
+ last_modified = data.get("last_modified", 0)
58
+ if isinstance(last_modified, Timestamp):
59
+ last_modified = int(last_modified.timestamp())
60
+
61
+ c.execute(
62
+ """
63
+ INSERT OR REPLACE INTO datasets
64
+ (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features)
65
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
66
+ """,
67
+ (
68
+ data["hub_id"],
69
+ data.get("likes", 0),
70
+ data.get("downloads", 0),
71
+ json.dumps(data.get("tags", []), default=serialize_numpy),
72
+ created_at,
73
+ last_modified,
74
+ json.dumps(data.get("license", []), default=serialize_numpy),
75
+ json.dumps(data.get("language", []), default=serialize_numpy),
76
+ data.get("config_name", ""),
77
+ json.dumps(data.get("column_names", []), default=serialize_numpy),
78
+ json.dumps(data.get("features", []), default=serialize_numpy),
79
+ ),
80
+ )
81
+ conn.commit()
82
+
83
+
84
+ @asynccontextmanager
85
+ async def lifespan(app: FastAPI):
86
+ # Startup: Load data into the database
87
+ setup_database()
88
+ conn = get_db_connection()
89
+ datasets = refresh_data()
90
+ for data in datasets:
91
+ insert_data(conn, data)
92
+ conn.close()
93
+ yield
94
+ # Shutdown: You can add any cleanup operations here if needed
95
+ # For example, closing database connections, clearing caches, etc.
96
+
97
+
98
+ app = FastAPI(lifespan=lifespan)
99
+
100
+
101
+ class SearchResponse(BaseModel):
102
+ total: int
103
+ page: int
104
+ page_size: int
105
+ results: List[dict]
106
+
107
+
108
+ @app.get("/search", response_model=SearchResponse)
109
+ async def search_datasets(
110
+ columns: List[str] = Query(...),
111
+ match_all: bool = Query(False),
112
+ page: int = Query(1, ge=1),
113
+ page_size: int = Query(10, ge=1, le=1000),
114
+ ):
115
+ offset = (page - 1) * page_size
116
+ conn = get_db_connection()
117
+ c = conn.cursor()
118
+
119
+ try:
120
+ if match_all:
121
+ query = """
122
+ SELECT COUNT(*) as total FROM datasets
123
+ WHERE (SELECT COUNT(*) FROM json_each(column_names)
124
+ WHERE value IN ({})) = ?
125
+ """.format(",".join("?" * len(columns)))
126
+ c.execute(query, (*columns, len(columns)))
127
+ else:
128
+ query = """
129
+ SELECT COUNT(*) as total FROM datasets
130
+ WHERE EXISTS (
131
+ SELECT 1 FROM json_each(column_names)
132
+ WHERE value IN ({})
133
+ )
134
+ """.format(",".join("?" * len(columns)))
135
+ c.execute(query, columns)
136
+
137
+ total = c.fetchone()["total"]
138
+
139
+ if match_all:
140
+ query = """
141
+ SELECT * FROM datasets
142
+ WHERE (SELECT COUNT(*) FROM json_each(column_names)
143
+ WHERE value IN ({})) = ?
144
+ LIMIT ? OFFSET ?
145
+ """.format(",".join("?" * len(columns)))
146
+ c.execute(query, (*columns, len(columns), page_size, offset))
147
+ else:
148
+ query = """
149
+ SELECT * FROM datasets
150
+ WHERE EXISTS (
151
+ SELECT 1 FROM json_each(column_names)
152
+ WHERE value IN ({})
153
+ )
154
+ LIMIT ? OFFSET ?
155
+ """.format(",".join("?" * len(columns)))
156
+ c.execute(query, (*columns, page_size, offset))
157
+
158
+ results = [dict(row) for row in c.fetchall()]
159
+
160
+ for result in results:
161
+ result["tags"] = json.loads(result["tags"])
162
+ result["license"] = json.loads(result["license"])
163
+ result["language"] = json.loads(result["language"])
164
+ result["column_names"] = json.loads(result["column_names"])
165
+ result["features"] = json.loads(result["features"])
166
+
167
+ return SearchResponse(
168
+ total=total, page=page, page_size=page_size, results=results
169
+ )
170
+
171
+ except sqlite3.Error as e:
172
+ raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
173
+ finally:
174
+ conn.close()
175
+
176
+
177
+ if __name__ == "__main__":
178
+ import uvicorn
179
+
180
+ uvicorn.run(app, host="0.0.0.0", port=8000)