File size: 9,493 Bytes
9889643 08ab464 add1f34 db180a4 add1f34 9889643 add1f34 9889643 add1f34 9889643 add1f34 9889643 add1f34 9889643 add1f34 9889643 add1f34 9889643 add1f34 9889643 add1f34 9889643 add1f34 9889643 add1f34 9889643 add1f34 9a573be 9889643 add1f34 9a573be 08ab464 add1f34 08ab464 9889643 add1f34 9889643 add1f34 9889643 add1f34 db180a4 94fe369 db180a4 9889643 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 |
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
import chromadb
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import requests
from itertools import combinations
import sqlite3
import pandas as pd
import os
import time
# Define FastAPI app
app = FastAPI()

# Browser origins allowed to call this API with credentials. Browsers send the
# Origin header as scheme://host[:port], so every entry must include the
# scheme; a bare "localhost:5173" entry could never match and has been removed.
origins = [
    "http://localhost:5173",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the sentence-embedding model and open the persistent ChromaDB
# collection once at import time; the request handlers reuse these globals.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
client = chromadb.PersistentClient(path='./chromadb')
# get_or_create_collection creates an empty collection if it is missing; the
# similarity endpoints only return results once it has been filled with
# symptom documents (presumably by a separate ingestion step — TODO confirm).
collection = client.get_or_create_collection(name="symptomsvector")
def init_db(db_path="diseases_symptoms.db"):
    """Open the disease database and make sure the ``diseases`` table exists.

    Args:
        db_path: Path of the SQLite file. It defaults to the file the request
            handlers read, and is created if it does not exist.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for closing it.
    """
    conn = sqlite3.connect(db_path)
    try:
        # IF NOT EXISTS makes this safe to call on an already-initialized DB.
        conn.execute('''
            CREATE TABLE IF NOT EXISTS diseases (
                id INTEGER PRIMARY KEY,
                name TEXT,
                symptoms TEXT,
                treatments TEXT
            )
        ''')
        conn.commit()
    except Exception:
        # Do not leak the connection when the DDL fails.
        conn.close()
        raise
    return conn
# Populate the database from the CSV the first time the app starts.
# The existence of the DB file is the "already populated" marker, so the file
# is created only after the dataset has been downloaded, and it is removed
# again if population fails; otherwise an empty/partial DB would be left behind
# and never repopulated.
if not os.path.exists("diseases_symptoms.db"):
    df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")
    # Missing symptom cells are NaN: treat them as "no symptoms" instead of
    # crashing in the split. Blank pieces (e.g. from a trailing comma) are dropped.
    df['Symptoms'] = (
        df['Symptoms'].fillna('').astype(str).str.split(',')
        .apply(lambda parts: [s.strip() for s in parts if s.strip()])
    )
    # Missing treatments become '' rather than NULL, so rows stay string-typed.
    if 'Treatments' in df.columns:
        treatment_col = df['Treatments'].fillna('').astype(str)
    else:
        treatment_col = pd.Series('', index=df.index)
    disease_rows = [
        (name, ",".join(symptoms), treatment)
        for name, symptoms, treatment in zip(df['Name'], df['Symptoms'], treatment_col)
    ]
    try:
        conn = init_db()
        try:
            # One batched insert and a single commit.
            conn.executemany(
                "INSERT INTO diseases (name, symptoms, treatments) VALUES (?, ?, ?)",
                disease_rows,
            )
            conn.commit()
        finally:
            conn.close()
    except Exception:
        if os.path.exists("diseases_symptoms.db"):
            os.remove("diseases_symptoms.db")
        raise
class SymptomQuery(BaseModel):
    """Request body carrying free-text symptoms.

    The endpoints that accept it split ``symptom`` on ',' and strip each
    piece, so several symptoms can be sent at once, e.g. "fever, cough".
    """

    symptom: str  # one or more comma-separated symptoms
def fetch_diseases_by_symptoms(matching_symptoms, db_path="diseases_symptoms.db"):
    """Fetch diseases whose stored symptom string contains the given symptoms.

    The symptoms are joined with ',' and searched for as one substring of the
    comma-separated ``symptoms`` column using SQLite ``LIKE``. They therefore
    must appear contiguously and in the same order as stored. ``LIKE``
    metacharacters in the input are escaped, so they match literally.

    Args:
        matching_symptoms: Iterable of symptom strings.
        db_path: SQLite database file to read; defaults to the app database.

    Returns:
        ``(disease_list, unique_symptoms_list)``:

        - ``disease_list`` is a list of dicts with 'Disease', 'Symptoms' (the
          stored string split on ',') and 'Treatments'.
        - ``unique_symptoms_list`` holds every symptom of the matched diseases,
          stripped and lower-cased, without duplicates, in first-seen order.
    """
    matching_symptom_str = ','.join(matching_symptoms)
    # Escape the escape character first, then the LIKE wildcards, so that
    # user text such as "50%" or "a_b" is not treated as a pattern.
    escaped = (
        matching_symptom_str
        .replace('\\', '\\\\')
        .replace('%', '\\%')
        .replace('_', '\\_')
    )

    disease_list = []
    unique_symptoms = {}  # dict as an insertion-ordered set (O(1) membership)
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT name, symptoms, treatments FROM diseases "
            "WHERE symptoms LIKE ? ESCAPE '\\'",
            (f'%{escaped}%',),
        )
        for name, symptoms, treatments in rows:
            disease_list.append({
                'Disease': name,
                'Symptoms': symptoms.split(','),
                'Treatments': treatments,
            })
            for symptom in symptoms.split(','):
                unique_symptoms.setdefault(symptom.strip().lower(), None)
    finally:
        conn.close()
    return disease_list, list(unique_symptoms)
@app.post("/find_matching_symptoms")
def find_matching_symptoms(query: SymptomQuery):
    """Return the stored symptom documents most similar to the user's input.

    ``query.symptom`` is split on ',' and each non-blank, stripped piece is
    embedded with ``model``. The 3 nearest documents per piece are then fetched
    from the ChromaDB ``collection``.

    Returns:
        ``{"matching_symptoms": [...]}``, with duplicates removed and the
        first-seen order kept.
    """
    symptoms = [s.strip() for s in query.symptom.split(',')]
    # Skip blank entries ("fever,,cough", a trailing comma, an empty body):
    # embedding "" only adds noise to the suggestions.
    symptoms = [s for s in symptoms if s]
    if not symptoms:
        return {"matching_symptoms": []}

    # One batched encode and one multi-query search. ChromaDB returns one list
    # of documents per query embedding, in the same order as the inputs.
    query_embeddings = model.encode(symptoms)
    results = collection.query(
        query_embeddings=query_embeddings.tolist(),
        n_results=3,
    )
    all_results = [doc for docs in results['documents'] for doc in docs]
    return {"matching_symptoms": list(dict.fromkeys(all_results))}
@app.post("/find_disease_list")
def find_disease_list(query: SymptomQuery):
    """Find diseases related to the user's symptoms via embedding search.

    Each comma-separated symptom in ``query.symptom`` is stripped and
    lower-cased. It is added to the module-level ``all_selected_symptoms`` set,
    embedded with ``model``, and used to fetch the 5 nearest documents from
    ``collection``. A disease from the SQLite ``diseases`` table is returned
    when any of its (normalized) symptoms equals one of those documents.

    Returns:
        A dict with two keys:

        - ``disease_list``: dicts with 'Disease', 'Symptoms' (normalized list)
          and 'Treatments'.
        - ``unique_symptoms_list``: the sorted symptoms of the returned
          diseases that were not part of this request.
    """
    selected_symptoms = [s.strip().lower() for s in query.symptom.split(',')]
    # Drop blank entries. A "" in the shared selection set can never be a
    # disease symptom, so /find_disease would return no diseases until reset.
    selected_symptoms = [s for s in selected_symptoms if s]
    all_selected_symptoms.update(selected_symptoms)

    matching_symptoms = set()
    if selected_symptoms:
        # One batched encode and one multi-query search. ChromaDB returns one
        # list of documents per query embedding.
        query_embeddings = model.encode(selected_symptoms)
        results = collection.query(
            query_embeddings=query_embeddings.tolist(),
            n_results=5,  # top 5 similar symptoms for each input symptom
        )
        # Normalize documents exactly like the database symptoms below, so a
        # casing/whitespace difference cannot hide a genuine match.
        matching_symptoms = {
            doc.strip().lower()
            for docs in results['documents']
            for doc in docs
            if doc
        }
        matching_symptoms.discard('')

    selected_set = set(selected_symptoms)
    disease_list = []
    unique_symptoms_set = set()
    conn = sqlite3.connect("diseases_symptoms.db")
    try:
        rows = conn.execute("SELECT name, symptoms, treatments FROM diseases")
        for disease_name, symptoms, treatments in rows:
            disease_symptoms = [s.strip().lower() for s in (symptoms or '').split(',')]
            # Include the disease only if at least one symptom matched.
            if matching_symptoms.isdisjoint(disease_symptoms):
                continue
            disease_list.append({
                'Disease': disease_name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments,
            })
            # Suggest only symptoms the user did not just submit.
            unique_symptoms_set.update(
                s for s in disease_symptoms if s not in selected_set
            )
    finally:
        conn.close()

    return {
        "disease_list": disease_list,
        # Sorted for a deterministic response order.
        "unique_symptoms_list": sorted(unique_symptoms_set),
    }
class SelectedSymptomsQuery(BaseModel):
    """Request body for /find_disease: the symptoms the user has picked."""

    # Item type is not validated; the endpoint calls .strip()/.lower() on each
    # item, so the items are expected to be strings.
    selected_symptoms: list


# Process-wide, in-memory set of every symptom selected so far. The endpoints
# that add to it store stripped, lower-cased strings. It is shared by all
# clients of this process and is emptied by POST /trigger-reload.
all_selected_symptoms: set = set()
@app.post("/find_disease")
def find_disease(query: SelectedSymptomsQuery):
    """Narrow the candidate diseases using every symptom selected so far.

    The request's symptoms are stripped, lower-cased and added to the
    module-level ``all_selected_symptoms`` set. A disease is returned only if
    its symptom list contains every symptom in that set; with an empty set,
    every disease qualifies.

    Returns:
        A dict with three keys:

        - ``unique_symptoms_list``: sorted symptoms of the returned diseases
          that are not yet selected.
        - ``all_selected_symptoms``: the accumulated selection, as a list.
        - ``disease_list``: dicts with 'Disease', 'Symptoms' and 'Treatments'.
    """
    new_symptoms = [symptom.strip().lower() for symptom in query.selected_symptoms]
    # Ignore blank entries: a "" in the shared set is never a disease symptom,
    # so it would make every later lookup return an empty disease list.
    all_selected_symptoms.update(s for s in new_symptoms if s)

    disease_list = []
    unique_symptoms_set = set()
    conn = sqlite3.connect("diseases_symptoms.db")
    try:
        rows = conn.execute("SELECT name, symptoms, treatments FROM diseases")
        for disease_name, symptoms, treatments in rows:
            disease_symptoms = [s.strip().lower() for s in (symptoms or '').split(',')]
            # Full match: every selected symptom must appear in this disease.
            if not all_selected_symptoms.issubset(disease_symptoms):
                continue
            disease_list.append({
                'Disease': disease_name,
                'Symptoms': disease_symptoms,
                'Treatments': treatments,
            })
            # Offer the disease's remaining, not-yet-selected symptoms.
            unique_symptoms_set.update(
                s for s in disease_symptoms if s not in all_selected_symptoms
            )
    finally:
        conn.close()

    return {
        # Sorted for a deterministic response order.
        "unique_symptoms_list": sorted(unique_symptoms_set),
        "all_selected_symptoms": list(all_selected_symptoms),  # set -> list for JSON
        "disease_list": disease_list,
    }
class DiseaseDetail(BaseModel):
    """Request body for /pass2llm: one disease record to be summarized.

    The first three fields use the same keys as the disease dicts returned by
    the lookup endpoints. ``MatchCount`` is supplied by the client.
    """

    Disease: str     # disease name
    Symptoms: list   # symptom strings (item type not validated)
    Treatments: str  # treatment text
    MatchCount: int  # presumably the number of matched symptoms — TODO confirm with the frontend
@app.post("/pass2llm")
def pass2llm(query: DiseaseDetail):
    """Ask the ``llama3`` model behind an ngrok tunnel to summarize a disease.

    The tunnel's public URL is looked up through the ngrok API. The API key is
    read from the ``NGROK_API_KEY`` environment variable. The prompt is then
    POSTed to ``{public_url}/api/generate``.

    Returns:
        A dict with "message" and either "llm_response" (the model's
        ``response`` field) on success, or "error" describing what failed.
    """
    ngrok_error = "Failed to get public URL from Ngrok!"
    llm_error = "Failed to get response from LLM!"

    # Never keep API keys in source. The key that used to be hard-coded here
    # is in version history and should be revoked in the ngrok dashboard.
    api_key = os.environ.get("NGROK_API_KEY")
    if not api_key:
        return {"message": ngrok_error, "error": "NGROK_API_KEY environment variable is not set"}

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Ngrok-Version": "2",
    }
    try:
        # Explicit timeouts keep a stalled upstream from hanging the request forever.
        response = requests.get("https://api.ngrok.com/endpoints", headers=headers, timeout=10)
    except requests.RequestException as exc:
        return {"message": ngrok_error, "error": str(exc)}
    if response.status_code != 200:
        return {"message": ngrok_error, "error": response.text}

    endpoints = response.json().get('endpoints') or []
    if not endpoints:
        return {"message": ngrok_error, "error": "No active ngrok endpoints found"}
    public_url = endpoints[0]['public_url']

    prompt = f"Here is a list of diseases and their details: {query}. Please generate a summary."
    llm_headers = {"Content-Type": "application/json"}
    llm_payload = {"model": "llama3", "prompt": prompt, "stream": False}
    try:
        # Non-streaming generation can be slow: short connect, long read timeout.
        llm_response = requests.post(
            f"{public_url}/api/generate",
            headers=llm_headers,
            json=llm_payload,
            timeout=(10, 300),
        )
    except requests.RequestException as exc:
        return {"message": llm_error, "error": str(exc)}
    if llm_response.status_code != 200:
        return {"message": llm_error, "error": llm_response.text}

    llm_response_json = llm_response.json()
    return {"message": "Successfully passed to LLM!", "llm_response": llm_response_json.get("response")}
@app.post("/trigger-reload")
async def trigger_reload():
    """Forget every symptom accumulated in the shared selection set.

    Returns the plain string "cleared".
    """
    # clear() mutates the module-level set in place, so no ``global``
    # declaration is needed to reach it.
    all_selected_symptoms.clear()
    status = "cleared"
    return status
# To run the FastAPI app with Uvicorn
# if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)
|