terapyon committed on
Commit
eeafe79
·
unverified ·
2 Parent(s): 7f1680d ee7e0ce

Merge pull request #4 from terapyon/terada/mt-241-streamlit-ui

Browse files
Files changed (4) hide show
  1. pyproject.toml +1 -0
  2. requirements.txt +2 -1
  3. src/app.py +74 -0
  4. src/embedding.py +7 -1
pyproject.toml CHANGED
@@ -14,6 +14,7 @@ dependencies = [
14
  "pyarrow>=18.1.0",
15
  "sentence-transformers>=3.3.1",
16
  "sentencepiece>=0.2.0",
 
17
  "torch>=2.5.1",
18
  "tqdm>=4.67.1",
19
  "unidic-lite>=1.0.8",
 
14
  "pyarrow>=18.1.0",
15
  "sentence-transformers>=3.3.1",
16
  "sentencepiece>=0.2.0",
17
+ "streamlit>=1.41.1",
18
  "torch>=2.5.1",
19
  "tqdm>=4.67.1",
20
  "unidic-lite>=1.0.8",
requirements.txt CHANGED
@@ -9,4 +9,5 @@ pandas
9
  numpy
10
  polars
11
  pyarrow
12
- duckdb
 
 
9
  numpy
10
  polars
11
  pyarrow
12
+ duckdb
13
+ streamlit
src/app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import duckdb
3
+ from embedding import get_embeddings
4
+ from config import DUCKDB_FILE
5
+
6
+
7
@st.cache_resource
def get_conn():
    """Open the DuckDB database and cache the connection.

    ``st.cache_resource`` makes Streamlit hand back the same connection
    object on every rerun instead of reconnecting each time.
    """
    connection = duckdb.connect(DUCKDB_FILE)
    return connection
10
+
11
+
12
# All episode titles, newest first — feeds the title multiselect widget.
title_query = """SELECT id, title FROM podcasts
ORDER BY date DESC;
"""

# Vector search: restrict to the selected podcast ids, rank embedding rows by
# distance to the query vector (closest 10), then join back to episode/podcast
# metadata for display.
# NOTE(review): `WHERE id in ?` relies on DuckDB binding a Python list to an
# IN clause — confirm the pinned duckdb version supports this.
query = """WITH filtered_podcasts AS (
    SELECT id
    FROM podcasts
    WHERE id in ?
),
ordered_embeddings AS (
    SELECT embeddings.id, embeddings.part
    FROM embeddings
    JOIN filtered_podcasts fp ON embeddings.id = fp.id
    ORDER BY array_distance(embedding, ?::FLOAT[1024])
    LIMIT 10
)
SELECT
    p.title,
    p.date,
    e.start,
    e.text,
    e.part,
    p.audio
FROM
    ordered_embeddings oe
JOIN
    episodes e
ON
    oe.id = e.id AND oe.part = e.part
JOIN
    podcasts p
ON
    oe.id = p.id;
"""
46
+
47
st.title("terapyon channel search")  # fix user-facing typo: "cannel" -> "channel"

conn = get_conn()
titles = conn.execute(title_query).df()

# Optional filter: search only inside the episodes the user picked.
# st.multiselect always returns a list (possibly empty), never None,
# so the annotation is list[str] rather than list[str] | None.
selected_title: list[str] = st.multiselect("Select title", titles["title"])
if selected_title:
    selected_ids = titles.loc[titles["title"].isin(selected_title), "id"].tolist()
else:
    st.write("All titles")
    selected_ids = titles["id"].tolist()

word = st.text_input("Search word")
if word:
    st.write(f"Search word: {word}")
    # Embed the search word with the query prefix; take the single row
    # produced for our one-element input list.
    embeddings = get_embeddings([word], query=True)
    word_embedding = embeddings[0, :]

    result = conn.execute(query, (selected_ids, word_embedding)).df()
    selected = st.dataframe(
        result,
        column_order=["title", "date", "part", "start", "text", "audio"],
        on_select="rerun",
        selection_mode="single-row",
    )
    if selected:
        rows = selected["selection"].get("rows")
        if rows:
            row = rows[0]
            # Show the transcript text of the selected row. Look the column up
            # by name instead of the brittle positional index iloc[row, 3].
            st.text(result.iloc[row]["text"])
src/embedding.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import numpy as np
2
  from sentence_transformers import SentenceTransformer
3
 
@@ -5,7 +6,11 @@ MODEL_NAME = "cl-nagoya/ruri-large"
5
  PREFIX_QUERY = "クエリ: " # "query: "
6
  PASSAGE_QUERY = "文章: " # "passage: "
7
 
8
- model = SentenceTransformer(MODEL_NAME)
 
 
 
 
9
 
10
 
11
  def get_embeddings(texts: list[str], query=False, passage=False) -> np.ndarray:
@@ -14,6 +19,7 @@ def get_embeddings(texts: list[str], query=False, passage=False) -> np.ndarray:
14
  if passage:
15
  texts = [PASSAGE_QUERY + text for text in texts]
16
  # texts = [text[i : i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
 
17
  embeddings = model.encode(texts)
18
  # print(embeddings.shape)
19
  # print(type(embeddings))
 
1
+ import streamlit as st
2
  import numpy as np
3
  from sentence_transformers import SentenceTransformer
4
 
 
6
  PREFIX_QUERY = "クエリ: " # "query: "
7
  PASSAGE_QUERY = "文章: " # "passage: "
8
 
9
+
10
@st.cache_resource
def get_sentence_model():
    """Load the SentenceTransformer model once; Streamlit caches the instance."""
    return SentenceTransformer(MODEL_NAME)
14
 
15
 
16
  def get_embeddings(texts: list[str], query=False, passage=False) -> np.ndarray:
 
19
  if passage:
20
  texts = [PASSAGE_QUERY + text for text in texts]
21
  # texts = [text[i : i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
22
+ model = get_sentence_model()
23
  embeddings = model.encode(texts)
24
  # print(embeddings.shape)
25
  # print(type(embeddings))