viboognesh committed on
Commit
ad7ff16
·
verified ·
1 Parent(s): bee4dc1

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +17 -0
  2. main.py +174 -0
  3. requirements.txt +11 -0
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use the official Python 3.12 image
FROM python:3.12

# Set the working directory to /app
WORKDIR /app

# Install dependencies before copying the application code so this layer is
# reused from cache whenever only the source files change.
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# The app writes to ./uploads and ./chroma_db at runtime. Hugging Face Docker
# Spaces run the container as UID 1000, so create that user and give it
# ownership of /app instead of leaving the directory root-owned.
RUN useradd -m -u 1000 user && chown user:user /app
USER user

# Copy the current directory contents into the container at /app
COPY --chown=user:user . /app

# Make port 7860 available to the world outside this container
EXPOSE 7860

# Run main.py when the container launches
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Depends
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from typing import List
4
+ import os
5
+ import aiofiles
6
+ import uuid
7
+ import shutil
8
+
9
+ # from dotenv import load_dotenv
10
+
11
+ from langchain_community.document_loaders import TextLoader, Docx2txtLoader, PyPDFLoader
12
+ from langchain.prompts import ChatPromptTemplate, PromptTemplate
13
+ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate
14
+ from langchain_community.document_loaders.csv_loader import CSVLoader
15
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
16
+ from langchain.memory import ConversationBufferMemory
17
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
18
+ from langchain_community.vectorstores import Chroma
19
+ from langchain.chains import ConversationalRetrievalChain
20
+
21
+ # load_dotenv()
22
+
23
# Browser origins that are allowed to call this API cross-origin.
origins = ["https://viboognesh-react-chat.static.hf.space"]

app = FastAPI()

# CORS policy: only the origins listed above, only the two HTTP methods the
# endpoints in this module use, any request header, credentials permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    allow_credentials=True,
)
34
+
35
+
36
class ConversationChainManager:
    """Build and hold the document-grounded conversational retrieval chain.

    When a class is used as a bare ``Depends()`` dependency, FastAPI creates a
    new instance for every request. The chain is therefore stored on the class
    rather than on the instance, so that a chain built by one instance stays
    visible to instances created afterwards.
    """

    # Chain shared by every instance in this process; None until one is built.
    _shared_conversation_chain = None

    def __init__(self):
        """Create the chat model and embedding clients with default settings."""
        # The shared chain is deliberately NOT reset here: resetting it on
        # construction would discard it each time a new manager is created.
        self.llm_model = ChatOpenAI()
        self.embeddings = OpenAIEmbeddings()

    @property
    def conversation_chain(self):
        """Return the shared chain, or ``None`` if none has been created yet."""
        return ConversationChainManager._shared_conversation_chain

    @conversation_chain.setter
    def conversation_chain(self, chain):
        """Replace the shared chain for every current and future instance."""
        ConversationChainManager._shared_conversation_chain = chain

    def create_conversational_chain(self, file_paths: List[str], session_id: str):
        """Index the given files and build a retrieval chain over them.

        Args:
            file_paths: Paths of the files to load (see ``get_docs`` for the
                supported extensions).
            session_id: Used as the Chroma collection name for this upload.
        """
        docs = self.get_docs(file_paths)
        memory = ConversationBufferMemory(
            memory_key="chat_history", return_messages=True
        )
        vectordb = Chroma.from_documents(
            docs,
            self.embeddings,
            collection_name=session_id,
            persist_directory="./chroma_db",
        )
        retriever = vectordb.as_retriever()
        self.conversation_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm_model,
            retriever=retriever,
            condense_question_prompt=self.get_question_generator_prompt(),
            combine_docs_chain_kwargs={
                "document_prompt": self.get_document_prompt(),
                "prompt": self.get_final_prompt(),
            },
            memory=memory,
        )

    @staticmethod
    def _split(documents: List) -> List:
        """Split loaded documents into 1000-character chunks with 100 overlap."""
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        return splitter.split_documents(documents)

    @staticmethod
    def get_docs(file_paths: List[str]) -> List:
        """Load every supported file and return the resulting documents.

        Supported extensions are ``.txt``, ``.csv``, ``.docx`` and ``.pdf``,
        matched case-insensitively. Paths with any other extension are
        skipped. Text and Word documents are chunked with ``_split``; CSV and
        PDF documents are returned as produced by their loaders.
        """
        docs = []
        for file_path in file_paths:
            # Compare on a lowercased copy so "REPORT.PDF" is not silently
            # ignored; the loader still receives the original path.
            lowered = file_path.lower()
            if lowered.endswith(".txt"):
                loaded = TextLoader(file_path).load()
                docs.extend(ConversationChainManager._split(loaded))
            elif lowered.endswith(".csv"):
                docs.extend(CSVLoader(file_path).load())
            elif lowered.endswith(".docx"):
                loaded = Docx2txtLoader(file_path).load()
                docs.extend(ConversationChainManager._split(loaded))
            elif lowered.endswith(".pdf"):
                docs.extend(PyPDFLoader(file_path).load_and_split())
        return docs

    @staticmethod
    def get_document_prompt() -> "PromptTemplate":
        """Return the template used to render one retrieved document."""
        document_template = """Document Content:{page_content}
Document Path: {source}"""
        return PromptTemplate(
            input_variables=["page_content", "source"],
            template=document_template,
        )

    @staticmethod
    def get_question_generator_prompt() -> "PromptTemplate":
        """Return the prompt that condenses history + follow-up into one question."""
        question_generator_template = """Combine the chat history and follow up question into
a standalone question.\n Chat History: {chat_history}\n
Follow up question: {question}
"""
        return PromptTemplate.from_template(question_generator_template)

    @staticmethod
    def get_final_prompt() -> "ChatPromptTemplate":
        """Return the chat prompt used to answer from context and chat history."""
        final_prompt_template = """Answer question based on the context and chat_history.
If you cannot find answers, ask more related questions from the user.
Use only the basename of the file path as name of the documents.
Mention document name of the documents you used in your answer.

context:
{context}

chat_history:
{chat_history}

question:
{question}

Answer:
"""

        messages = [
            SystemMessagePromptTemplate.from_template(final_prompt_template),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]

        return ChatPromptTemplate.from_messages(messages)
137
+
138
+
139
@app.post("/upload_files/")
async def upload_files(
    files: List[UploadFile] = File(...),
    conversation_chain_manager: ConversationChainManager = Depends(),
):
    """Save the uploaded files temporarily and build a chain over them.

    Each upload gets its own ``uploads/<uuid>`` folder. The files are written
    there, passed to ``create_conversational_chain``, and the folder is removed
    afterwards whether or not indexing succeeded.

    Returns:
        A dict with a single ``"message"`` key confirming the chain exists.
    """
    session_id = str(uuid.uuid4())
    session_folder = f"uploads/{session_id}"
    os.makedirs(session_folder, exist_ok=True)
    try:
        file_paths = []
        for file in files:
            # ``filename`` is supplied by the client. Keep only its final path
            # component so a name such as "../../main.py" cannot be written
            # outside the session folder.
            raw_name = (file.filename or "").replace("\\", "/")
            safe_name = os.path.basename(raw_name)
            if safe_name in ("", ".", ".."):
                # Nothing usable to name the file with; skip this part.
                continue
            file_path = f"{session_folder}/{safe_name}"
            async with aiofiles.open(file_path, "wb") as out_file:
                content = await file.read()
                await out_file.write(content)
            file_paths.append(file_path)

        conversation_chain_manager.create_conversational_chain(file_paths, session_id)
    finally:
        # Always remove the temporary copies, even when saving or indexing
        # raised; ignore cleanup errors so they do not mask the real failure.
        shutil.rmtree(session_folder, ignore_errors=True)
    print("conversational_chain_manager created")
    return {"message": "ConversationalRetrievalChain is created. Please ask questions."}
159
+
160
+
161
@app.get("/predict/")
async def predict(
    query: str, conversation_chain_manager: ConversationChainManager = Depends()
):
    """Answer ``query`` and return it as ``{"answer": ...}``.

    If the manager holds a conversation chain, the query is passed to it and
    the chain's ``"answer"`` field is returned. Otherwise the query is sent
    directly to the manager's chat model, prefixed with an instruction that
    asks the user to upload files.
    """
    chain = conversation_chain_manager.conversation_chain
    if chain is not None:
        answer = chain.invoke(query)["answer"]
    else:
        system_prompt = "Answer the question and also ask the user to upload files to ask questions from the files.\n"
        reply = conversation_chain_manager.llm_model.invoke(system_prompt + query)
        answer = reply.content

    print("predict called")
    return {"answer": answer}
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
fastapi
uvicorn
sqlalchemy
langchain_community
langchain
pypdf
langchain_openai
python-dotenv
python-multipart
chromadb
aiofiles
# Needed by langchain's Docx2txtLoader, which main.py uses for .docx uploads
docx2txt