Irakoze commited on
Commit
001bd50
·
verified ·
1 Parent(s): 5a9dc8f

added first files

Browse files
Files changed (3) hide show
  1. app.py +445 -0
  2. requirements.txt +9 -0
  3. translation_service.py +51 -0
app.py ADDED
@@ -0,0 +1,445 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import logging
4
+ import asyncio
5
+ from dotenv import load_dotenv
6
+ from langchain.prompts import PromptTemplate
7
+ from langchain_qdrant import QdrantVectorStore
8
+ from langchain.chains import RetrievalQA
9
+ from langchain_groq import ChatGroq
10
+ from qdrant_client.models import PointStruct, VectorParams, Distance
11
+ import uuid
12
+ from qdrant_client.http import models
13
+ from datetime import datetime
14
+ from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
15
+ from qdrant_client import QdrantClient
16
+ import cohere
17
+ from langchain.retrievers import ContextualCompressionRetriever
18
+ from langchain_cohere import CohereRerank
19
+ import re
20
+ from translation_service import TranslationService
21
+
22
+ # Load environment variables
23
+ load_dotenv()
24
+
25
+ # Initialize logging with INFO level and detailed format
26
+ logging.basicConfig(
27
+ filename='app.log',
28
+ level=logging.INFO,
29
+ format='%(asctime)s - %(levelname)s - %(message)s'
30
+ )
31
+
32
+ # Initialize services
33
+ translator = TranslationService()
34
+
35
+ def initialize_database_client():
36
+ """Initialize Qdrant client"""
37
+ try:
38
+ client = QdrantClient(
39
+ url=os.getenv("QDURL"),
40
+ api_key=os.getenv("API_KEY1"),
41
+ verify=True # Set to True if using SSL
42
+ )
43
+ logging.info("Qdrant client initialized successfully.")
44
+ return client
45
+ except Exception as e:
46
+ logging.error(f"Failed to initialize Qdrant client: {e}")
47
+ raise
48
+
49
+ def initialize_llm():
50
+ """Initialize LLM with fallback"""
51
+ try:
52
+ llm = ChatGroq(
53
+ temperature=0,
54
+ model_name="llama3-8b-8192",
55
+ api_key=os.getenv("GROQ_API_KEY")
56
+ )
57
+ logging.info("ChatGroq initialized with model llama3-8b-8192.")
58
+ return llm
59
+ except Exception as e:
60
+ logging.warning(f"Failed to initialize ChatGroq with llama3: {e}. Falling back to mixtral.")
61
+ try:
62
+ llm = ChatGroq(
63
+ temperature=0,
64
+ model_name="mixtral-8x7b-32768",
65
+ api_key=os.getenv("GROQ_API_KEY")
66
+ )
67
+ logging.info("ChatGroq initialized with fallback model mixtral-8x7b-32768.")
68
+ return llm
69
+ except Exception as fallback_e:
70
+ logging.error(f"Failed to initialize fallback LLM: {fallback_e}")
71
+ raise
72
+
73
+ def initialize_services():
74
+ """Initialize all services"""
75
+ try:
76
+ # Initialize Qdrant client
77
+ client = initialize_database_client()
78
+
79
+ # Initialize embeddings
80
+ embeddings = FastEmbedEmbeddings(model_name="nomic-ai/nomic-embed-text-v1.5-Q")
81
+ logging.info("FastEmbedEmbeddings initialized successfully.")
82
+
83
+ # Initialize Qdrant DB
84
+ db = QdrantVectorStore(
85
+ client=client,
86
+ embedding=embeddings,
87
+ collection_name="RR3"
88
+ )
89
+ logging.info("QdrantVectorStore initialized with collection 'RR3'.")
90
+
91
+ # Initialize retriever with reranker
92
+ cohere_client = cohere.Client(api_key=os.getenv("COHERE_API_KEY"))
93
+ reranker = CohereRerank(
94
+ client=cohere_client,
95
+ top_n=3,
96
+ model="rerank-multilingual-v3.0"
97
+ )
98
+ base_retriever = db.as_retriever(search_kwargs={"k": 14})
99
+ retriever = ContextualCompressionRetriever(
100
+ base_compressor=reranker,
101
+ base_retriever=base_retriever
102
+ )
103
+ logging.info("Retriever with reranker initialized successfully.")
104
+
105
+ # Initialize LLM
106
+ llm = initialize_llm()
107
+
108
+ return retriever, llm
109
+ except Exception as e:
110
+ logging.error(f"Service initialization error: {str(e)}")
111
+ raise
112
+
113
+ def initialize_feedback_collection():
114
+ """Initialize and verify feedback collection"""
115
+ try:
116
+ client = initialize_database_client()
117
+
118
+ # Check if collection exists
119
+ collections = client.get_collections().collections
120
+ collection_exists = any(c.name == "chat_feedback" for c in collections)
121
+
122
+ if not collection_exists:
123
+ # Create collection with proper configuration
124
+ client.create_collection(
125
+ collection_name="chat_feedback",
126
+ vectors_config=VectorParams(
127
+ size=768, # Ensure this matches the embedding size
128
+ distance=Distance.COSINE
129
+ )
130
+ )
131
+ logging.info("Created 'chat_feedback' collection with vector size 768 and Cosine distance.")
132
+ else:
133
+ logging.info("'chat_feedback' collection already exists.")
134
+
135
+ # Verify collection exists and has correct configuration
136
+ collection_info = client.get_collection("chat_feedback")
137
+ if collection_info.config.params.vectors.size != 768:
138
+ raise ValueError("Incorrect vector size in 'chat_feedback' collection.")
139
+ logging.info("'chat_feedback' collection verified successfully with correct vector size.")
140
+
141
+ return True
142
+ except Exception as e:
143
+ logging.error(f"Failed to initialize feedback collection: {e}")
144
+ raise
145
+
146
+ async def submit_feedback(feedback_type, chat_history, language_choice):
147
+ """Submit feedback with improved error handling and logging."""
148
+ try:
149
+ if not chat_history or len(chat_history) < 2:
150
+ logging.warning("Attempted to submit feedback with insufficient chat history.")
151
+ return "No recent interaction to provide feedback for."
152
+
153
+ # Get last question and answer
154
+ last_interaction = chat_history[-2:]
155
+ question = last_interaction[0].get("content", "").strip()
156
+ answer = last_interaction[1].get("content", "").strip()
157
+
158
+ if not question or not answer:
159
+ logging.warning("Question or answer content is missing.")
160
+ return "Incomplete interaction data. Cannot submit feedback."
161
+
162
+ logging.info(f"Processing feedback for question: {question[:50]}...")
163
+
164
+ # Initialize client
165
+ client = initialize_database_client()
166
+
167
+ # Create point ID
168
+ point_id = str(uuid.uuid4())
169
+
170
+ # Create payload
171
+ payload = {
172
+ "question": question,
173
+ "answer": answer,
174
+ "language": language_choice,
175
+ "timestamp": datetime.utcnow().isoformat(),
176
+ "feedback": feedback_type
177
+ }
178
+
179
+ # Initialize embeddings
180
+ embeddings = FastEmbedEmbeddings(model_name="nomic-ai/nomic-embed-text-v1.5-Q")
181
+
182
+ # Create embeddings for the Q&A pair
183
+ try:
184
+ embedding_text = f"{question} {answer}"
185
+ vector = await asyncio.to_thread(embeddings.embed_query, embedding_text)
186
+ logging.info(f"Generated embedding vector of length {len(vector)}.")
187
+ except Exception as embed_error:
188
+ logging.error(f"Embedding generation failed: {embed_error}")
189
+ return "Failed to generate embeddings for your feedback."
190
+
191
+ if not isinstance(vector, list) or not vector:
192
+ logging.error("Invalid vector generated from embeddings.")
193
+ return "Failed to generate valid embeddings for your feedback."
194
+
195
+ # Create point
196
+ point = PointStruct(
197
+ id=point_id,
198
+ payload=payload,
199
+ vector=vector
200
+ )
201
+
202
+ # Store in Qdrant
203
+ try:
204
+ operation_info = await asyncio.to_thread(
205
+ client.upsert,
206
+ collection_name="chat_feedback",
207
+ points=[point]
208
+ )
209
+ logging.info(f"Feedback submitted successfully: {point_id}")
210
+ return "Thanks for your feedback! Your response has been recorded."
211
+ except Exception as db_error:
212
+ logging.error(f"Failed to upsert point to Qdrant: {db_error}")
213
+ return "Sorry, there was an error submitting your feedback."
214
+
215
+ except Exception as e:
216
+ logging.error(f"Unexpected error in submit_feedback: {e}")
217
+ return "Sorry, there was an unexpected error submitting your feedback."
218
+
219
+ # Initialize services and feedback collection
220
+ try:
221
+ retriever, llm = initialize_services()
222
+ initialize_feedback_collection()
223
+ except Exception as initialization_error:
224
+ logging.critical(f"Initialization failed: {initialization_error}")
225
+ raise
226
+
227
+ # Prompt template
228
+ prompt_template = PromptTemplate(
229
+ template="""You are RRA Assistant, created by Cedric to help users get tax related information in Rwanda. Your task is to answer tax-related questions using the provided context.
230
+
231
+ Context: {context}
232
+
233
+ User's Question: {question}
234
+
235
+ Please follow these steps to answer the question:
236
+
237
+ Step 1: Analyze the question
238
+ Briefly explain your understanding of the question and any key points to address. If it is hi or hello, skip to step 3 and respond with a greeting.
239
+
240
+ Step 2: Provide relevant information
241
+ Using the context provided, give detailed information related to the question. Include specific facts, figures, or explanations from the context.
242
+
243
+ Step 3: Final answer
244
+ Provide a clear, concise answer to the original question. Start directly with the relevant information, avoiding phrases like "In summary" or "To conclude".
245
+
246
+ Remember:
247
+ - If you don't know the answer or can't find relevant information in the context, say so honestly.
248
+ - Do not make up information.
249
+ - Use the provided context to support your answer.
250
+ - Include "For more information, call 3004" at the end of every answer.
251
+
252
+ Your response:
253
+ """,
254
+ input_variables=['context', 'question']
255
+ )
256
+
257
+ async def process_query(message: str, language: str, chat_history: list) -> str:
258
+ try:
259
+ # Handle translation based on selected language
260
+ if language == "Kinyarwanda":
261
+ query = translator.translate(message, "rw", "en")
262
+ logging.info(f"Translated query to English: {query}")
263
+ else:
264
+ query = message
265
+
266
+ # Create QA chain
267
+ qa = RetrievalQA.from_chain_type(
268
+ llm=llm,
269
+ chain_type="stuff",
270
+ retriever=retriever,
271
+ chain_type_kwargs={"prompt": prompt_template},
272
+ return_source_documents=True
273
+ )
274
+
275
+ # Get response
276
+ response = await asyncio.to_thread(
277
+ lambda: qa.invoke({"query": query})
278
+ )
279
+ logging.info("QA chain invoked successfully.")
280
+
281
+ # Extract final answer
282
+ result_text = response.get('result', '')
283
+ final_answer_start = result_text.find("Step 3: Final answer")
284
+ if final_answer_start != -1:
285
+ answer = result_text[final_answer_start + len("Step 3: Final answer"):].strip()
286
+ else:
287
+ answer = result_text
288
+
289
+ # Clean up the answer
290
+ answer = re.sub(r'\*\*', '', answer).strip()
291
+ answer = re.sub(r'Step \d+:', '', answer).strip()
292
+
293
+ # Translate response if needed
294
+ if language == "Kinyarwanda":
295
+ answer = translator.translate(answer, "en", "rw")
296
+ logging.info(f"Translated answer to Kinyarwanda: {answer}")
297
+
298
+ return answer
299
+ except Exception as e:
300
+ logging.error(f"Query processing error: {str(e)}")
301
+ return f"An error occurred: {str(e)}"
302
+
303
+ # Define separate feedback submission functions to pass feedback type correctly
304
+ async def submit_positive_feedback(chat_history, language_choice):
305
+ return await submit_feedback("positive", chat_history, language_choice)
306
+
307
+ async def submit_negative_feedback(chat_history, language_choice):
308
+ return await submit_feedback("negative", chat_history, language_choice)
309
+
310
+ # Create Gradio interface
311
+ with gr.Blocks(title="RRA FAQ Chatbot") as demo:
312
+ gr.Markdown(
313
+ """
314
+ # RRA FAQ Chatbot
315
+ Ask tax-related questions in English or Kinyarwanda
316
+ > 🔒 Your questions and interactions remain private unless you choose to submit feedback, which helps improve our service.
317
+ """
318
+ )
319
+
320
+ # Add language selector
321
+ language = gr.Radio(
322
+ choices=["English", "Kinyarwanda"],
323
+ value="English",
324
+ label="Select Language / Hitamo Ururimi"
325
+ )
326
+
327
+ chatbot = gr.Chatbot(
328
+ value=[],
329
+ show_label=False,
330
+ height=400,
331
+ type='messages'
332
+ )
333
+
334
+ with gr.Row():
335
+ msg = gr.Textbox(
336
+ label="Ask your question",
337
+ placeholder="Type your tax-related question here...",
338
+ show_label=False
339
+ )
340
+ submit = gr.Button("Send")
341
+
342
+ # Add feedback section
343
+ with gr.Row():
344
+ with gr.Column(scale=2):
345
+ feedback_label = gr.Markdown("Was this response helpful?")
346
+ with gr.Column(scale=1):
347
+ feedback_positive = gr.Button("👍 Helpful")
348
+ with gr.Column(scale=1):
349
+ feedback_negative = gr.Button("👎 Not Helpful")
350
+
351
+ # Add feedback status message
352
+ feedback_status = gr.Markdown("")
353
+
354
+ # Connect feedback buttons to their respective functions
355
+ feedback_positive.click(
356
+ fn=submit_positive_feedback,
357
+ inputs=[chatbot, language],
358
+ outputs=feedback_status
359
+ )
360
+
361
+ feedback_negative.click(
362
+ fn=submit_negative_feedback,
363
+ inputs=[chatbot, language],
364
+ outputs=feedback_status
365
+ )
366
+
367
+ # Create two sets of examples
368
+ with gr.Row() as english_examples_row:
369
+ gr.Examples(
370
+ examples=[
371
+ "What is VAT in Rwanda?",
372
+ "How do I register for taxes?",
373
+ "What are the tax payment deadlines?",
374
+ "How can I get a TIN number?",
375
+ "How do I get purchase code?"
376
+ ],
377
+ inputs=msg,
378
+ label="English Examples"
379
+ )
380
+
381
+ with gr.Row(visible=False) as kinyarwanda_examples_row:
382
+ gr.Examples(
383
+ examples=[
384
+ "Ese VAT ni iki mu Rwanda?",
385
+ "Nabona TIN number nte?",
386
+ "Ni ryari tugomba kwishyura imisoro?",
387
+ "Ese nandikwa nte ku musoro?",
388
+ "Ni gute nabone kode yo kugura?"
389
+ ],
390
+ inputs=msg,
391
+ label="Kinyarwanda Examples"
392
+ )
393
+
394
+ async def respond(message, lang, chat_history):
395
+ bot_message = await process_query(message, lang, chat_history)
396
+ chat_history.append({"role": "user", "content": message})
397
+ chat_history.append({"role": "assistant", "content": bot_message})
398
+ return "", chat_history
399
+
400
+ def toggle_language_interface(language_choice):
401
+ if language_choice == "English":
402
+ placeholder_text = "Type your tax-related question here..."
403
+ return {
404
+ msg: gr.update(placeholder=placeholder_text),
405
+ english_examples_row: gr.update(visible=True),
406
+ kinyarwanda_examples_row: gr.update(visible=False)
407
+ }
408
+ else:
409
+ placeholder_text = "Andika ibibazo bijyanye n'umusoro hano"
410
+ return {
411
+ msg: gr.update(placeholder=placeholder_text),
412
+ english_examples_row: gr.update(visible=False),
413
+ kinyarwanda_examples_row: gr.update(visible=True)
414
+ }
415
+
416
+ msg.submit(respond, [msg, language, chatbot], [msg, chatbot])
417
+ submit.click(respond, [msg, language, chatbot], [msg, chatbot])
418
+
419
+ # Update both examples visibility and placeholder when language changes
420
+ language.change(
421
+ fn=toggle_language_interface,
422
+ inputs=language,
423
+ outputs=[msg, english_examples_row, kinyarwanda_examples_row]
424
+ )
425
+
426
+ gr.Markdown(
427
+ """
428
+ ### About
429
+ - Created by: [Cedric](mailto:[email protected])
430
+ - Data source: [RRA Website FAQ](https://www.rra.gov.rw/en/domestic-tax-services/faqs)
431
+
432
+ **Disclaimer:** This chatbot provides general tax information. For official guidance,
433
+ consult RRA or call 3004.
434
+ 🔒 **Privacy:** Your interactions remain private unless you choose to submit feedback.
435
+ """
436
+ )
437
+
438
+ # Launch the app
439
+ if __name__ == "__main__":
440
+ try:
441
+ demo.launch(share=False)
442
+ logging.info("Gradio app launched successfully.")
443
+ except Exception as launch_error:
444
+ logging.critical(f"Failed to launch Gradio app: {launch_error}")
445
+ raise
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.14.0
2
+ python-dotenv>=1.0.0
3
+ langchain>=0.1.0
4
+ langchain-community>=0.0.13
5
+ langchain-groq>=0.1.1
6
+ cohere>=4.37
7
+ qdrant-client>=1.7.0
8
+ requests>=2.31.0
9
+ fastembeddings>=0.0.11
translation_service.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import requests
4
+ from pydantic import BaseModel
5
+ from dotenv import load_dotenv
6
+
7
+ # Load environment variables
8
+ load_dotenv()
9
+
10
+ class TranslationRequest(BaseModel):
11
+ src: str
12
+ tgt: str
13
+ use_multi: str
14
+ text: str
15
+
16
+ class Config:
17
+ populate_by_name = True
18
+
19
+ class TranslationService:
20
+ def __init__(self):
21
+ self.api_url = os.getenv('TRANSLATION_API_URL')
22
+ if not self.api_url:
23
+ raise ValueError("TRANSLATION_API_URL environment variable is not set")
24
+
25
+ def translate(self, text: str, src_language: str, tgt_language: str) -> str:
26
+ try:
27
+ payload = TranslationRequest(
28
+ src=src_language,
29
+ tgt=tgt_language,
30
+ use_multi="MULTI",
31
+ text=text
32
+ )
33
+
34
+ response = requests.post(
35
+ self.api_url,
36
+ headers={
37
+ "accept": "application/json",
38
+ "Content-Type": "application/json"
39
+ },
40
+ json=payload.model_dump()
41
+ )
42
+
43
+ if response.status_code == 200:
44
+ return response.json().get("translation")
45
+ elif response.status_code == 406:
46
+ raise ValueError("Invalid language pair selected")
47
+ else:
48
+ raise ValueError(f"Translation failed with status code {response.status_code}")
49
+ except Exception as e:
50
+ logging.error(f"Translation error: {str(e)}")
51
+ return text