sakuexe commited on
Commit
0b367ea
·
1 Parent(s): 6427fd5

tweaked the code a bit to make answering faster

Browse files
Files changed (2) hide show
  1. app.py +7 -7
  2. vector_store.py +11 -11
app.py CHANGED
@@ -2,13 +2,12 @@
2
  # https://huggingface.co/learn/cookbook/rag_zephyr_langchain
3
  # langchain
4
  from typing import TypedDict
5
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
6
- from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_core.runnables import RunnablePassthrough
9
  from langchain_huggingface import HuggingFacePipeline
10
  # huggingface
11
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
12
  from transformers import pipeline
13
  # pytorch
14
  import torch
@@ -59,6 +58,11 @@ text_generation_pipeline = pipeline(
59
 
60
  llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
61
 
 
 
 
 
 
62
 
63
  def generate_prompt(message_history: list[ChatMessage], max_history=5):
64
  # creating the prompt template in the shape of a chat prompt
@@ -99,10 +103,6 @@ def generate_prompt(message_history: list[ChatMessage], max_history=5):
99
 
100
 
101
  async def generate_answer(message_history: list[ChatMessage]):
102
- # generate a vector store
103
- print("creating the document database")
104
- db = await get_document_database("learning_material/*/*/*")
105
- print("Document database is ready")
106
 
107
  # initialize the similarity search
108
  n_of_best_results = 4
 
2
  # https://huggingface.co/learn/cookbook/rag_zephyr_langchain
3
  # langchain
4
  from typing import TypedDict
5
+ from langchain_core.prompts import ChatPromptTemplate
 
6
  from langchain_core.output_parsers import StrOutputParser
7
  from langchain_core.runnables import RunnablePassthrough
8
  from langchain_huggingface import HuggingFacePipeline
9
  # huggingface
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM
11
  from transformers import pipeline
12
  # pytorch
13
  import torch
 
58
 
59
  llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
60
 
61
+ # generate a vector store
62
+ print("creating the document database")
63
+ db = get_document_database("learning_material/*/*/*")
64
+ print("Document database is ready")
65
+
66
 
67
  def generate_prompt(message_history: list[ChatMessage], max_history=5):
68
  # creating the prompt template in the shape of a chat prompt
 
103
 
104
 
105
  async def generate_answer(message_history: list[ChatMessage]):
 
 
 
 
106
 
107
  # initialize the similarity search
108
  n_of_best_results = 4
vector_store.py CHANGED
@@ -10,21 +10,21 @@ from glob import glob
10
  import pathlib
11
 
12
 
13
- async def load_text(file_path: str) -> list[Document] | None:
14
  """Loads text documents (.txt) asynchronously from a passed file_path."""
15
  assert file_path != ""
16
  assert pathlib.Path(file_path).suffix == ".txt"
17
 
18
  try:
19
  loader = TextLoader(file_path)
20
- return await loader.aload()
21
  except UnicodeError or RuntimeError as err:
22
  print(f"could not load file: {file_path}")
23
  print(f"error: {err}")
24
 
25
 
26
  # https://python.langchain.com/docs/how_to/document_loader_markdown/
27
- async def load_markdown(file_path: str) -> list[Document] | None:
28
  """Loads markdown files asynchronously from a passed file_path."""
29
  assert file_path != ""
30
  assert pathlib.Path(file_path).suffix == ".md"
@@ -33,33 +33,33 @@ async def load_markdown(file_path: str) -> list[Document] | None:
33
  # use the mode elements to keep metadata about if the information is
34
  # a paragraph, link or a heading for example
35
  loader = UnstructuredMarkdownLoader(file_path, mode="elements")
36
- return await loader.aload()
37
  except UnicodeError or RuntimeError as err:
38
  print(f"could not load file: {file_path}")
39
  print(f"error: {err}")
40
 
41
 
42
  # https://python.langchain.com/docs/how_to/document_loader_pdf/
43
- async def load_pdf(file_path: str) -> list[Document] | None:
44
  """Loads pdf documents (.pdf) asynchronously from a passed file_path."""
45
  assert file_path != ""
46
  assert pathlib.Path(file_path).suffix == ".pdf"
47
 
48
  loader = PyPDFLoader(file_path)
49
  try:
50
- return await loader.aload()
51
  except PyPdfError as err:
52
  print(f"could not read file: {file_path}")
53
  print(f"error: {err}")
54
 
55
 
56
- async def load_html(file_path: str) -> list[Document]:
57
  """Loads html documents (.html) asynchronously from a passed file_path."""
58
  assert file_path != ""
59
  assert pathlib.Path(file_path).suffix == ".html" or ".htm"
60
 
61
  loader = BSHTMLLoader(file_path)
62
- return await loader.aload()
63
 
64
 
65
  # hold all of the loader functions for easy 0(1) fetching
@@ -73,7 +73,7 @@ LOADER_MAP = {
73
 
74
 
75
  # https://python.langchain.com/v0.1/docs/modules/data_connection/retrievers/vectorstore/
76
- async def get_document_database(
77
  data_folder="learning_material/*/*/*",
78
  embedding_model="BAAI/bge-base-en-v1.5",
79
  chunk_size=1028, chunk_overlap=0,
@@ -96,7 +96,7 @@ async def get_document_database(
96
  continue
97
 
98
  # load the document with a filetype specific loader
99
- result_documents = await load_fn(file_path)
100
 
101
  if not result_documents:
102
  print(f"file {file_path} does not include any content, skipping")
@@ -111,7 +111,7 @@ async def get_document_database(
111
 
112
  chunked_docs = splitter.split_documents(all_docs)
113
 
114
- return await FAISS.afrom_documents(
115
  chunked_docs,
116
  HuggingFaceEmbeddings(model_name=embedding_model)
117
  )
 
10
  import pathlib
11
 
12
 
13
+ def load_text(file_path: str) -> list[Document] | None:
14
  """Loads text documents (.txt) asynchronously from a passed file_path."""
15
  assert file_path != ""
16
  assert pathlib.Path(file_path).suffix == ".txt"
17
 
18
  try:
19
  loader = TextLoader(file_path)
20
+ return loader.load()
21
  except UnicodeError or RuntimeError as err:
22
  print(f"could not load file: {file_path}")
23
  print(f"error: {err}")
24
 
25
 
26
  # https://python.langchain.com/docs/how_to/document_loader_markdown/
27
+ def load_markdown(file_path: str) -> list[Document] | None:
28
  """Loads markdown files asynchronously from a passed file_path."""
29
  assert file_path != ""
30
  assert pathlib.Path(file_path).suffix == ".md"
 
33
  # use the mode elements to keep metadata about if the information is
34
  # a paragraph, link or a heading for example
35
  loader = UnstructuredMarkdownLoader(file_path, mode="elements")
36
+ return loader.load()
37
  except UnicodeError or RuntimeError as err:
38
  print(f"could not load file: {file_path}")
39
  print(f"error: {err}")
40
 
41
 
42
  # https://python.langchain.com/docs/how_to/document_loader_pdf/
43
+ def load_pdf(file_path: str) -> list[Document] | None:
44
  """Loads pdf documents (.pdf) asynchronously from a passed file_path."""
45
  assert file_path != ""
46
  assert pathlib.Path(file_path).suffix == ".pdf"
47
 
48
  loader = PyPDFLoader(file_path)
49
  try:
50
+ return loader.load()
51
  except PyPdfError as err:
52
  print(f"could not read file: {file_path}")
53
  print(f"error: {err}")
54
 
55
 
56
+ def load_html(file_path: str) -> list[Document]:
57
  """Loads html documents (.html) asynchronously from a passed file_path."""
58
  assert file_path != ""
59
  assert pathlib.Path(file_path).suffix == ".html" or ".htm"
60
 
61
  loader = BSHTMLLoader(file_path)
62
+ return loader.load()
63
 
64
 
65
  # hold all of the loader functions for easy 0(1) fetching
 
73
 
74
 
75
  # https://python.langchain.com/v0.1/docs/modules/data_connection/retrievers/vectorstore/
76
+ def get_document_database(
77
  data_folder="learning_material/*/*/*",
78
  embedding_model="BAAI/bge-base-en-v1.5",
79
  chunk_size=1028, chunk_overlap=0,
 
96
  continue
97
 
98
  # load the document with a filetype specific loader
99
+ result_documents = load_fn(file_path)
100
 
101
  if not result_documents:
102
  print(f"file {file_path} does not include any content, skipping")
 
111
 
112
  chunked_docs = splitter.split_documents(all_docs)
113
 
114
+ return FAISS.from_documents(
115
  chunked_docs,
116
  HuggingFaceEmbeddings(model_name=embedding_model)
117
  )