File size: 2,919 Bytes
ef781c2
59d39d4
ef781c2
d788666
9d69587
 
ef781c2
 
 
d788666
9d69587
ef781c2
 
 
59d39d4
ef781c2
 
 
 
9d69587
ef781c2
 
 
 
 
 
 
 
 
 
 
9d69587
 
 
 
 
 
d788666
59d39d4
 
 
 
 
 
 
 
 
 
d788666
59d39d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d69587
 
 
 
ef781c2
 
 
 
 
 
9d69587
 
59d39d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef781c2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import duckdb
from embedding import get_embeddings
from config import DUCKDB_FILE
from config import PODCAST_TITLE_LIST, EPISODES_PARQUET


def create_table():
    """Create the ``podcasts``, ``episodes`` and ``embeddings`` tables.

    All three ``CREATE TABLE`` statements run inside one explicit
    transaction, so a failure part-way through leaves no partial schema
    behind. The connection is closed even when a statement raises.

    Raises a DuckDB error if any statement fails. This includes the case
    where one of the tables already exists, since plain ``CREATE TABLE``
    is used rather than ``IF NOT EXISTS``.
    """
    podcasts_create = """CREATE TABLE podcasts (
        id BIGINT PRIMARY KEY,
        title TEXT, date DATE, guests TEXT[], length BIGINT, audio TEXT
        );
    """
    episodes_create = """CREATE TABLE episodes (
        id BIGINT, part BIGINT, start BIGINT, end_ BIGINT, text TEXT,
        PRIMARY KEY (id, part)
        );
    """
    # Fixed-size vector column. NOTE(review): 1024 presumably matches the
    # output dimension of get_embeddings -- confirm if the model changes.
    embeddings_create = """CREATE TABLE embeddings (
        id BIGINT, part BIGINT, embedding FLOAT[1024],
        PRIMARY KEY (id, part)
        );
    """
    # ``with`` releases the database file even if a statement raises.
    with duckdb.connect(DUCKDB_FILE) as conn:
        # One transaction: the schema is created all-or-nothing.
        conn.begin()
        for statement in (podcasts_create, episodes_create, embeddings_create):
            conn.execute(statement)
        conn.commit()
    print("Tables created.")


def insert_podcast():
    """Load podcast metadata from Parquet into the ``podcasts`` table.

    Reads ``id``, ``title``, ``date``, ``length`` and ``audio`` from the
    Parquet source named by ``PODCAST_TITLE_LIST``. It stores an empty
    ``guests`` list for every row. The connection is closed even if the
    insert fails.
    """
    # Naming the target columns keeps the insert correct even if the
    # table's column order ever differs from this SELECT list.
    sql = """INSERT INTO podcasts (id, title, date, guests, length, audio)
        SELECT id, title, date, [], length, audio
          FROM read_parquet(?);
    """
    with duckdb.connect(DUCKDB_FILE) as conn:
        # DuckDB runs in autocommit mode by default. This single
        # INSERT ... SELECT is therefore committed as one transaction when
        # execute() returns.
        conn.execute(sql, [PODCAST_TITLE_LIST])


def insert_episodes():
    """Load episode text chunks from Parquet into the ``episodes`` table.

    Copies the ``id``, ``part``, ``start``, ``end_`` and ``text`` columns
    from the Parquet file named by ``EPISODES_PARQUET``. The connection is
    closed even if the insert fails.
    """
    # Explicit target columns decouple the insert from table column order.
    sql = """INSERT INTO episodes (id, part, start, end_, text)
        SELECT id, part, start, end_, text
          FROM read_parquet(?);
    """
    with duckdb.connect(DUCKDB_FILE) as conn:
        # Autocommit (DuckDB's default) commits this single statement
        # atomically.
        conn.execute(sql, [EPISODES_PARQUET])


def embed_store():
    """Embed every episode chunk's text and store the vectors.

    Reads ``(id, part, text)`` from ``episodes`` and passes all texts to
    ``get_embeddings`` in one call. It then inserts one
    ``(id, part, embedding)`` row per chunk into ``embeddings``.

    All inserts run in a single transaction. If anything fails mid-way,
    nothing is written, so a re-run does not collide with half-inserted
    primary keys. When ``episodes`` is empty, the model is not called at all.

    Raises:
        ValueError: if ``get_embeddings`` returns a different number of
            vectors than texts it was given. Without this check, ``zip``
            would silently drop the unmatched rows.
    """
    with duckdb.connect(DUCKDB_FILE) as conn:
        data = conn.execute("SELECT id, part, text FROM episodes;").df()
        texts = data["text"].tolist()
        if not texts:
            # Nothing to embed; skip calling the model with an empty batch.
            return
        # Materialize so the count can be validated. NOTE(review): each
        # item is presumably an array-like vector exposing ``.tolist()``.
        embeddings = list(get_embeddings(texts))
        if len(embeddings) != len(texts):
            raise ValueError(
                f"get_embeddings returned {len(embeddings)} vectors "
                f"for {len(texts)} texts"
            )
        # Convert pandas/numpy scalars to plain ints for parameter binding.
        rows = [
            (int(id_), int(part), emb.tolist())
            for id_, part, emb in zip(data["id"], data["part"], embeddings)
        ]
        conn.begin()
        conn.executemany("INSERT INTO embeddings VALUES (?, ?, ?)", rows)
        conn.commit()


def create_index():
    """Build an HNSW vector index on ``embeddings.embedding``.

    Uses DuckDB's ``vss`` extension. No ``metric`` option is given, so the
    extension's default distance metric applies. The connection is closed
    even if a statement fails.
    """
    with duckdb.connect(DUCKDB_FILE) as conn:
        # INSTALL does not re-download an already-installed extension.
        # Without it, LOAD fails on a fresh DuckDB setup.
        conn.execute("INSTALL vss;")
        conn.execute("LOAD vss;")
        # HNSW indexes in a file-backed database require this
        # (experimental) persistence setting.
        conn.execute("SET hnsw_enable_experimental_persistence=true;")
        conn.execute(
            """CREATE INDEX embeddings_index
                 ON embeddings USING HNSW (embedding);"""
        )


if __name__ == "__main__":
    import sys

    # Each sub-command maps to the pipeline steps it runs, in order.
    commands = {
        "create": (create_table,),
        "podcastinsert": (insert_podcast,),
        "episodeinsert": (insert_episodes,),
        "embed": (embed_store,),
        "index": (create_index,),
        "all": (
            create_table,
            insert_podcast,
            insert_episodes,
            embed_store,
            create_index,
        ),
    }
    args = sys.argv
    if len(args) != 2 or args[1] not in commands:
        # One usage message listing every valid sub-command. sys.exit(str)
        # writes it to stderr and exits with status 1.
        sys.exit("Usage: python store.py {" + "|".join(commands) + "}")
    for step in commands[args[1]]:
        step()