Louxads committed
Commit 0aafe72 · verified · 1 Parent(s): ff77a67

Update app.py

Files changed (1)
  1. app.py +1 -114
app.py CHANGED
@@ -1,116 +1,3 @@
- import os
- import time
  import gradio as gr
- from langchain_community.vectorstores import FAISS
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain.prompts import PromptTemplate
- from langchain.memory import ConversationBufferWindowMemory
- from langchain.chains import ConversationalRetrievalChain
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
- import torch
 
- # Initialize embeddings and FAISS database
- def load_embeddings():
-     return HuggingFaceEmbeddings(model_name="law-ai/InLegalBERT")
-
- def load_faiss_db():
-     embeddings = load_embeddings()
-     return FAISS.load_local("ipc_embed_db", embeddings, allow_dangerous_deserialization=True)
-
- # Load embeddings and FAISS database
- embeddings = load_embeddings()
- db = load_faiss_db()
- db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
-
- # Define prompt template
- prompt_template = """
- <s>[INST]
- As a legal chatbot specializing in the Indian Penal Code, you are tasked with providing highly accurate and contextually appropriate responses. Ensure your answers meet these criteria:
- - Respond in a bullet-point format to clearly delineate distinct aspects of the legal query.
- - Each point should accurately reflect the breadth of the legal provision in question, avoiding over-specificity unless directly relevant to the user's query.
- - Clarify the general applicability of the legal rules or sections mentioned, highlighting any common misconceptions or frequently misunderstood aspects.
- - Limit responses to essential information that directly addresses the user's question, providing concise yet comprehensive explanations.
- - Avoid assuming specific contexts or details not provided in the query, focusing on delivering universally applicable legal interpretations unless otherwise specified.
- - Conclude with a brief summary that captures the essence of the legal discussion and corrects any common misinterpretations related to the topic.
-
- CONTEXT: {context}
- CHAT HISTORY: {chat_history}
- QUESTION: {question}
- ANSWER:
- - [Detail the first key aspect of the law, ensuring it reflects general application]
- - [Provide a concise explanation of how the law is typically interpreted or applied]
- - [Correct a common misconception or clarify a frequently misunderstood aspect]
- - [Detail any exceptions to the general rule, if applicable]
- - [Include any additional relevant information that directly relates to the user's query]
- </s>[INST]
- """
-
- prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question', 'chat_history'])
-
- # Load the InLegalBERT model and tokenizer
- tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")
- model = AutoModelForSequenceClassification.from_pretrained("law-ai/InLegalBERT")
-
- # Function to get the model's response
- def get_inlegalbert_response(question):
-     inputs = tokenizer(question, return_tensors="pt")
-     with torch.no_grad():
-         outputs = model(**inputs)
-     logits = outputs.logits
-     response = tokenizer.decode(torch.argmax(logits, dim=-1))
-     return response
-
- # Define a wrapper for the model
- class InLegalBERTWrapper:
-     def __init__(self, model, tokenizer):
-         self.model = model
-         self.tokenizer = tokenizer
-
-     def __call__(self, prompt, **kwargs):
-         return {"text": get_inlegalbert_response(prompt)}
-
- llm = InLegalBERTWrapper(model, tokenizer)
-
- qa = ConversationalRetrievalChain.from_llm(
-     llm=llm,
-     memory=ConversationBufferWindowMemory(k=2, memory_key="chat_history", return_messages=True),
-     retriever=db_retriever,
-     combine_docs_chain_kwargs={'prompt': prompt}
- )
-
- def extract_answer(full_response):
-     answer_start = full_response.find("Response:")
-     if answer_start != -1:
-         answer_start += len("Response:")
-         return full_response[answer_start:].strip()
-     return full_response
-
- def chat(input_prompt, messages):
-     if "messages" not in messages:
-         messages["messages"] = []
-
-     messages["messages"].append({"role": "user", "content": input_prompt})
-
-     result = qa.invoke(input=input_prompt)
-     answer = extract_answer(result["answer"])
-
-     messages["messages"].append({"role": "assistant", "content": answer})
-
-     return [(message["role"], message["content"]) for message in messages["messages"]], messages
-
- with gr.Blocks() as demo:
-     gr.Markdown("## Stat.ai Legal Assistant")
-     chatbot = gr.Chatbot()
-     state = gr.State({"messages": []})
-     msg = gr.Textbox(placeholder="Ask Stat.ai")
-
-     def user_input(message, history):
-         history["messages"].append({"role": "user", "content": message})
-         return "", history
-
-     msg.submit(user_input, [msg, state], [msg, state], queue=False).then(
-         chat, [msg, state], [chatbot, state]
-     )
-
- if __name__ == "__main__":
-     demo.launch()
+ gr.load("models/law-ai/InLegalBERT").launch()
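
The rewritten app.py drops the local RAG pipeline (FAISS index, LangChain chain, and locally loaded InLegalBERT weights) in favor of a single `gr.load` call. A minimal sketch of what that one-liner does, assuming the model is served through the hosted Inference API (the `hf_token` argument is an assumption, only needed for gated models or higher rate limits):

```python
import gradio as gr

# The "models/" prefix tells gr.load to fetch law-ai/InLegalBERT from the
# Hugging Face Hub and back the demo with the hosted Inference API, so no
# weights, tokenizer, or vector store are loaded inside this Space.
demo = gr.load("models/law-ai/InLegalBERT")  # optionally: hf_token="hf_..."

# launch() starts the Gradio server, exactly as the old Blocks app did.
demo.launch()
```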