terapyon committed on
Commit 1ce86c7 · 1 Parent(s): 2b32e82

modify getting AI model for cache
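
In effect, the commit stops instantiating the SentenceTransformer at import time and instead wraps the load in a get_sentence_model() helper cached with Streamlit's st.cache_resource, so the model is loaded once per server process and reused across script reruns. A minimal standalone sketch of the resulting pattern (names taken from the diff below; the behavior comments describe Streamlit's documented st.cache_resource semantics, not anything specific to this repository):

    import streamlit as st
    from sentence_transformers import SentenceTransformer

    MODEL_NAME = "cl-nagoya/ruri-large"


    @st.cache_resource
    def get_sentence_model() -> SentenceTransformer:
        # The decorated function body runs only on the first call; the returned
        # object is kept in a process-wide cache, and the same model instance is
        # handed back on every subsequent rerun and to every session.
        return SentenceTransformer(MODEL_NAME)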

Files changed (1): src/embedding.py (+7 -1)
src/embedding.py CHANGED
@@ -1,3 +1,4 @@
+import streamlit as st
 import numpy as np
 from sentence_transformers import SentenceTransformer
 
@@ -5,7 +6,11 @@ MODEL_NAME = "cl-nagoya/ruri-large"
 PREFIX_QUERY = "クエリ: " # "query: "
 PASSAGE_QUERY = "文章: " # "passage: "
 
-model = SentenceTransformer(MODEL_NAME)
+
+@st.cache_resource
+def get_sentence_model():
+    model = SentenceTransformer(MODEL_NAME)
+    return model
 
 
 def get_embeddings(texts: list[str], query=False, passage=False) -> np.ndarray:
@@ -14,6 +19,7 @@ def get_embeddings(texts: list[str], query=False, passage=False) -> np.ndarray:
     if passage:
         texts = [PASSAGE_QUERY + text for text in texts]
     # texts = [text[i : i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
+    model = get_sentence_model()
     embeddings = model.encode(texts)
     # print(embeddings.shape)
     # print(type(embeddings))
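
For context, here is a minimal sketch of how the updated module might be driven from a Streamlit page. The file name app.py, the page content, and the import path are assumptions for illustration only and are not part of this commit:

    # app.py (hypothetical; assumes the app is launched from the repository root)
    import streamlit as st

    from src.embedding import get_embeddings

    st.title("Embedding demo")

    query = st.text_input("Query", "日本の首都はどこですか")
    if query:
        # get_embeddings() now fetches the model via get_sentence_model(); thanks to
        # @st.cache_resource the SentenceTransformer is loaded on the first run only
        # and reused on every rerun of the script.
        vectors = get_embeddings([query], query=True)
        st.write("Embedding shape:", vectors.shape)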