File size: 2,731 Bytes
2b515b3
89e36c5
6fd8495
 
 
89e36c5
6fd8495
 
 
 
89e36c5
 
 
 
 
 
 
 
 
 
6fd8495
 
2b32e82
 
 
 
 
 
 
 
 
 
8584df7
2b32e82
 
8584df7
 
6fd8495
 
 
 
 
2b32e82
 
 
8584df7
6fd8495
 
 
 
 
 
 
 
 
8584df7
 
6fd8495
 
 
 
2b32e82
 
 
 
 
 
 
 
 
6fd8495
 
 
 
 
 
2b32e82
 
 
8584df7
6fd8495
 
 
4e75967
6fd8495
 
 
2b515b3
 
 
4e75967
 
 
 
 
2b515b3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from datetime import timedelta
import os
import streamlit as st
import duckdb
from embedding import get_embeddings
from config import HF_HOST, DUCKDB_FILE, HF_REPO_TYPE, HF_REPO_ID, HF_FILENAME


@st.cache_resource
def get_conn():
    """Open the DuckDB database backing the search UI.

    The result is cached through ``st.cache_resource``.

    When ``HF_HOST`` is truthy, the database file is fetched with
    ``hf_hub_download`` (using ``HF_REPO_TYPE``, ``HF_REPO_ID`` and
    ``HF_FILENAME``) and the downloaded path is opened. Otherwise the
    connection is opened on ``DUCKDB_FILE``.

    Returns:
        The object returned by ``duckdb.connect``.
    """
    if not HF_HOST:
        return duckdb.connect(DUCKDB_FILE)

    # Mirror HF_TOKEN (empty string when unset) into HUGGINGFACE_TOKEN
    # before huggingface_hub is imported.
    token = os.environ.get("HF_TOKEN", "")
    os.environ["HUGGINGFACE_TOKEN"] = token

    from huggingface_hub import hf_hub_download

    download_args = {
        "repo_type": HF_REPO_TYPE,
        "repo_id": HF_REPO_ID,
        "filename": HF_FILENAME,
    }
    db_path = hf_hub_download(**download_args)
    return duckdb.connect(db_path)


# Lists every podcast's id and title, ordered by date descending.
title_query = """SELECT id, title FROM podcasts
    ORDER BY date DESC;
"""

# Search query with two positional parameters, in this order:
#   1. the collection of podcast ids to restrict the search to (`id in ?`)
#   2. the query embedding, cast to FLOAT[1024] inside the SQL
# Stages:
#   - filtered_podcasts: podcast ids matching parameter 1
#   - ordered_embeddings: embeddings of those podcasts ranked by
#     array_distance to parameter 2, keeping the 10 smallest distances
#   - final SELECT: joins each hit to `episodes` (on id and part) and
#     `podcasts` (on id), ordered by distance
# The resulting column order is: title, date, start, text, part, audio, distance.
query = """WITH filtered_podcasts AS (
    SELECT id
      FROM podcasts
        WHERE id in ?
),
ordered_embeddings AS (
    SELECT embeddings.id, embeddings.part, array_distance(embedding, ?::FLOAT[1024]) AS distance
    FROM embeddings
    JOIN filtered_podcasts fp ON embeddings.id = fp.id
    ORDER BY distance
      LIMIT 10
)
SELECT
    p.title,
    p.date,
    e.start,
    e.text,
    e.part,
    p.audio,
    oe.distance,
  FROM
      ordered_embeddings oe
  JOIN
      episodes e
    ON
      oe.id = e.id AND oe.part = e.part
  JOIN
      podcasts p
    ON
      oe.id = p.id
  ORDER BY oe.distance;
"""

# --- Page body: title filter, search box, result table, audio playback ---
st.title("terapyon channel search")

conn = get_conn()
titles = conn.execute(title_query).df()

# Optional filter: restrict the search to the chosen podcast titles.
selected_title: list[str] | None = st.multiselect("Select title", titles["title"])
if selected_title:
    selected_ids = titles.loc[titles["title"].isin(selected_title), "id"].tolist()
else:
    # Nothing picked: search across every podcast.
    st.write("All titles")
    selected_ids = titles["id"].tolist()

word = st.text_input("Search word")
if word:
    st.write(f"Search word: {word}")
    embeddings = get_embeddings([word], query=True)
    word_embedding = embeddings[0, :]

    # Parameters are positional: id filter first, then the query embedding.
    result = conn.execute(query, (selected_ids, word_embedding)).df()
    selected = st.dataframe(
        result,
        column_order=["title", "date", "part", "start", "distance", "text", "audio"],
        on_select="rerun",
        selection_mode="single-row",
    )

    rows = selected["selection"].get("rows") if selected else None
    if rows:
        # Look fields up by column name so a change to the SELECT list
        # cannot silently point at the wrong column.
        hit = result.iloc[rows[0]]
        # float() works for NumPy scalars and plain Python numbers alike.
        start = float(hit["start"])
        start_delta = timedelta(seconds=start)
        # Start a few seconds early for context, but never before 0:
        # hits in the first 5 seconds would otherwise yield a negative offset.
        playback_start = max(start - 5.0, 0.0)

        if st.button("オーディオを再生"):
            st.write(f"Start time: {start_delta}")
            st.audio(hit["audio"], start_time=playback_start)
        st.text(hit["text"])