Louxads committed
Commit ff77a67 · verified · 1 Parent(s): 8a10951

Update app.py

Files changed (1)
  1. app.py +111 -4
app.py CHANGED
@@ -1,9 +1,116 @@
 
 
 import gradio as gr

-def greet(name):
-    return f"Hello {name}!!"

-iface = gr.Interface(fn=greet, inputs="text", outputs="text")

 if __name__ == "__main__":
-    iface.launch()
 
+import os
+import time
 import gradio as gr
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain.prompts import PromptTemplate
+from langchain.memory import ConversationBufferWindowMemory
+from langchain.chains import ConversationalRetrievalChain
+from langchain_core.language_models.llms import LLM
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch

+# Initialize embeddings backed by InLegalBERT
+def load_embeddings():
+    return HuggingFaceEmbeddings(model_name="law-ai/InLegalBERT")

+# Load the prebuilt FAISS index of IPC embeddings from disk
+def load_faiss_db():
+    embeddings = load_embeddings()
+    return FAISS.load_local("ipc_embed_db", embeddings, allow_dangerous_deserialization=True)
+
+# Load embeddings and the FAISS database, then expose a top-3 similarity retriever
+embeddings = load_embeddings()
+db = load_faiss_db()
+db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
+
+# Prompt template combining retrieved context, windowed chat history, and the question
+prompt_template = """
+<s>[INST]
+As a legal chatbot specializing in the Indian Penal Code, you are tasked with providing highly accurate and contextually appropriate responses. Ensure your answers meet these criteria:
+- Respond in a bullet-point format to clearly delineate distinct aspects of the legal query.
+- Each point should accurately reflect the breadth of the legal provision in question, avoiding over-specificity unless directly relevant to the user's query.
+- Clarify the general applicability of the legal rules or sections mentioned, highlighting any common misconceptions or frequently misunderstood aspects.
+- Limit responses to essential information that directly addresses the user's question, providing concise yet comprehensive explanations.
+- Avoid assuming specific contexts or details not provided in the query, focusing on delivering universally applicable legal interpretations unless otherwise specified.
+- Conclude with a brief summary that captures the essence of the legal discussion and corrects any common misinterpretations related to the topic.
+
+CONTEXT: {context}
+CHAT HISTORY: {chat_history}
+QUESTION: {question}
+ANSWER:
+- [Detail the first key aspect of the law, ensuring it reflects general application]
+- [Provide a concise explanation of how the law is typically interpreted or applied]
+- [Correct a common misconception or clarify a frequently misunderstood aspect]
+- [Detail any exceptions to the general rule, if applicable]
+- [Include any additional relevant information that directly relates to the user's query]
+[/INST]
+"""
+
+prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question', 'chat_history'])
+
+# Load the InLegalBERT model and tokenizer
+tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")
+model = AutoModelForSequenceClassification.from_pretrained("law-ai/InLegalBERT")
+
+# Get the model's response for a prompt.
+# NOTE: InLegalBERT is an encoder-only model, so decoding the argmax of its
+# classification logits yields a single token id, not generated text; this is
+# a placeholder until a generative model is wired in.
+def get_inlegalbert_response(question):
+    inputs = tokenizer(question, return_tensors="pt", truncation=True)
+    with torch.no_grad():
+        outputs = model(**inputs)
+    logits = outputs.logits
+    response = tokenizer.decode(torch.argmax(logits, dim=-1))
+    return response
+
+# Wrap the model as a LangChain LLM so ConversationalRetrievalChain accepts it
+class InLegalBERTWrapper(LLM):
+    @property
+    def _llm_type(self):
+        return "inlegalbert"
+
+    def _call(self, prompt, stop=None, run_manager=None, **kwargs):
+        return get_inlegalbert_response(prompt)
+
+llm = InLegalBERTWrapper()
+
+qa = ConversationalRetrievalChain.from_llm(
+    llm=llm,
+    memory=ConversationBufferWindowMemory(k=2, memory_key="chat_history", return_messages=True),
+    retriever=db_retriever,
+    combine_docs_chain_kwargs={'prompt': prompt}
+)
+
+# The template ends with "ANSWER:", so strip everything up to that marker
+def extract_answer(full_response):
+    answer_start = full_response.find("ANSWER:")
+    if answer_start != -1:
+        answer_start += len("ANSWER:")
+        return full_response[answer_start:].strip()
+    return full_response
+
+def chat(input_prompt, messages):
+    # user_input() has already recorded the question and cleared the textbox,
+    # so read it from state rather than from the (now empty) input_prompt.
+    question = messages["messages"][-1]["content"]
+    result = qa.invoke({"question": question})
+    answer = extract_answer(result["answer"])
+    messages["messages"].append({"role": "assistant", "content": answer})
+    # gr.Chatbot renders (user, assistant) tuples, so pair up the role/content log
+    pairs = []
+    for m in messages["messages"]:
+        if m["role"] == "user":
+            pairs.append((m["content"], None))
+        elif pairs:
+            pairs[-1] = (pairs[-1][0], m["content"])
+    return pairs, messages
+
+with gr.Blocks() as demo:
+    gr.Markdown("## Stat.ai Legal Assistant")
+    chatbot = gr.Chatbot()
+    state = gr.State({"messages": []})
+    msg = gr.Textbox(placeholder="Ask Stat.ai")
+
+    def user_input(message, history):
+        history["messages"].append({"role": "user", "content": message})
+        return "", history
+
+    msg.submit(user_input, [msg, state], [msg, state], queue=False).then(
+        chat, [msg, state], [chatbot, state]
+    )

 if __name__ == "__main__":
+    demo.launch()
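
Note: app.py loads a prebuilt FAISS index from the ipc_embed_db directory, which this commit assumes already exists in the Space. Below is a minimal sketch of how such an index might be built with the same embedding model; the script name, the ipc.txt source file, and the chunking parameters are assumptions, not taken from this repo.

# build_index.py — hypothetical one-off indexing script (not part of this commit)
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Load the IPC text (assumed to live in ipc.txt) and split it into chunks
docs = TextLoader("ipc.txt").load()
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = splitter.split_documents(docs)

# Embed the chunks with the same model app.py uses for retrieval
embeddings = HuggingFaceEmbeddings(model_name="law-ai/InLegalBERT")
db = FAISS.from_documents(chunks, embeddings)
db.save_local("ipc_embed_db")  # the directory app.py loads with FAISS.load_local

Loading it back is then exactly what load_faiss_db() does; allow_dangerous_deserialization=True is required because FAISS.load_local unpickles the stored docstore.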