# Author: Shakir60
# Commit 8461811 (verified): "Update app.py"
import streamlit as st
import torch
from PIL import Image
import numpy as np
from transformers import ViTFeatureExtractor, ViTForImageClassification
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt
import logging
import faiss
from typing import List, Dict
from datetime import datetime
from groq import Groq
import os
from functools import lru_cache
# Configure root logging at INFO at import time, then grab this module's logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class RAGSystem:
    """Retrieval over a small, built-in knowledge base of construction defects.

    The sentence-embedding model, the knowledge base and the FAISS index are all
    built lazily on first access, so constructing an instance is cheap.
    ``get_relevant_context`` returns the ``k`` nearest documents (L2 distance
    between embeddings) joined into a single string.
    """

    def __init__(self):
        # Load models only when needed
        self._embedding_model = None
        self._vector_store = None
        self._knowledge_base = None
        # Per-instance LRU of *successful* lookups, keyed by (query, k). A
        # class-level ``lru_cache`` on the method would key on ``self`` (keeping
        # every instance alive) and would also cache the "" returned after a failure.
        self._cached_retrieve = lru_cache(maxsize=32)(self._retrieve)

    @property
    def embedding_model(self):
        """SentenceTransformer used for documents and queries, loaded on first access."""
        if self._embedding_model is None:
            self._embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        return self._embedding_model

    @property
    def knowledge_base(self):
        """List of ``{"text", "metadata"}`` documents, loaded on first access."""
        if self._knowledge_base is None:
            self._knowledge_base = self.load_knowledge_base()
        return self._knowledge_base

    @property
    def vector_store(self):
        """FAISS index over ``knowledge_base`` embeddings, built on first access."""
        if self._vector_store is None:
            self._vector_store = self.create_vector_store()
        return self._vector_store

    @staticmethod
    @lru_cache(maxsize=1)  # Cache the knowledge base
    def load_knowledge_base() -> List[Dict]:
        """Load and preprocess the knowledge base.

        Returns:
            One document per defect entry: ``{"text": str, "metadata": {"category": str}}``.
            ``text`` is ``"Category: <category>\\n"`` followed by one ``"<key>: <value>\\n"``
            line per field. The list is cached and shared between callers, so it
            must not be mutated.
        """
        kb = {
            "spalling": [
                {
                    "severity": "Critical",
                    "description": "Severe concrete spalling with exposed reinforcement",
                    "repair_method": "Remove deteriorated concrete, clean reinforcement",
                    "immediate_action": "Evacuate area, install support",
                    "prevention": "Regular inspections, waterproofing"
                }
            ],
            "structural_cracks": [
                {
                    "severity": "High",
                    "description": "Active structural cracks >5mm width",
                    "repair_method": "Structural analysis, epoxy injection",
                    "immediate_action": "Install crack monitors",
                    "prevention": "Regular monitoring, load management"
                }
            ],
            "surface_deterioration": [
                {
                    "severity": "Medium",
                    "description": "Surface scaling and deterioration",
                    "repair_method": "Surface preparation, patch repair",
                    "immediate_action": "Document extent, plan repairs",
                    "prevention": "Surface sealers, proper drainage"
                }
            ],
            "corrosion": [
                {
                    "severity": "High",
                    "description": "Corrosion of reinforcement leading to cracks",
                    "repair_method": "Remove rust, apply inhibitors",
                    "immediate_action": "Isolate affected area",
                    "prevention": "Anti-corrosion coatings, proper drainage"
                }
            ],
            "efflorescence": [
                {
                    "severity": "Low",
                    "description": "White powder deposits on concrete surfaces",
                    "repair_method": "Surface cleaning, sealant application",
                    "immediate_action": "Identify moisture source",
                    "prevention": "Improve waterproofing, reduce moisture ingress"
                }
            ],
            "delamination": [
                {
                    "severity": "Medium",
                    "description": "Separation of layers in concrete",
                    "repair_method": "Resurface or replace delaminated sections",
                    "immediate_action": "Inspect bonding layers",
                    "prevention": "Proper curing and bonding agents"
                }
            ],
            "honeycombing": [
                {
                    "severity": "Medium",
                    "description": "Voids in concrete caused by improper compaction",
                    "repair_method": "Grout injection, patch repair",
                    "immediate_action": "Assess structural impact",
                    "prevention": "Proper vibration during pouring"
                }
            ],
            "water_leakage": [
                {
                    "severity": "High",
                    "description": "Water ingress through cracks or joints",
                    "repair_method": "Injection grouting, waterproofing membranes",
                    "immediate_action": "Stop water flow, apply sealants",
                    "prevention": "Drainage systems, joint sealing"
                }
            ],
            "settlement_cracks": [
                {
                    "severity": "High",
                    "description": "Cracks due to uneven foundation settlement",
                    "repair_method": "Foundation underpinning, grouting",
                    "immediate_action": "Monitor movement, stabilize foundation",
                    "prevention": "Soil compaction, proper foundation design"
                }
            ],
            "shrinkage_cracks": [
                {
                    "severity": "Low",
                    "description": "Minor cracks caused by shrinkage during curing",
                    "repair_method": "Sealant application, surface repairs",
                    "immediate_action": "Monitor cracks",
                    "prevention": "Proper curing and moisture control"
                }
            ]
        }
        documents = []
        for category, items in kb.items():
            for item in items:
                fields = "".join(f"{key}: {value}\n" for key, value in item.items())
                documents.append({
                    "text": f"Category: {category}\n{fields}",
                    "metadata": {"category": category},
                })
        return documents

    def create_vector_store(self):
        """Embed every knowledge-base document into an exact (flat) L2 FAISS index."""
        texts = [doc["text"] for doc in self.knowledge_base]
        # FAISS only accepts float32 input.
        embeddings = np.asarray(self.embedding_model.encode(texts), dtype='float32')
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index

    def get_relevant_context(self, query: str, k: int = 2) -> str:
        """Retrieve relevant context based on query.

        Args:
            query: Free-text user question.
            k: Maximum number of documents to return; clamped to the index size.

        Returns:
            The texts of up to ``k`` nearest documents joined by blank lines, or
            ``""`` when ``k <= 0`` or retrieval fails (the failure is logged and,
            unlike successes, not cached).
        """
        try:
            return self._cached_retrieve(query, k)
        except Exception as e:
            logger.error(f"Error retrieving context: {e}")
            return ""

    def _retrieve(self, query: str, k: int) -> str:
        """Uncached nearest-neighbour lookup; raises on failure so errors aren't memoized."""
        index = self.vector_store
        k = min(k, index.ntotal)
        if k <= 0:
            return ""
        query_embedding = np.asarray(self.embedding_model.encode([query]), dtype='float32')
        _, neighbours = index.search(query_embedding, k)
        # FAISS pads missing results with -1; skip them rather than letting
        # ``knowledge_base[-1]`` silently return the last document.
        return "\n\n".join(
            self.knowledge_base[i]["text"] for i in neighbours[0] if i >= 0
        )
class ImageAnalyzer:
    """Classify a construction image into one of ``defect_classes``.

    The model and its feature extractor are loaded lazily from ``model_name`` on
    first use. ``_load_model`` installs a freshly initialised classifier head
    sized to ``defect_classes``. NOTE(review): nothing in this class loads
    fine-tuned head weights, so outputs presumably come from an untrained head —
    confirm before trusting the predictions.
    """

    def __init__(self, model_name="microsoft/swin-base-patch4-window7-224-in22k"):
        self.device = "cpu"
        self.defect_classes = ["spalling", "structural_cracks", "surface_deterioration"]
        self.model_name = model_name
        self._model = None
        self._feature_extractor = None

    @property
    def model(self):
        """Classification model, loaded on first access (``None`` if loading failed)."""
        if self._model is None:
            self._model = self._load_model()
        return self._model

    @property
    def feature_extractor(self):
        """Feature extractor, loaded on first access (``None`` if loading failed)."""
        if self._feature_extractor is None:
            self._feature_extractor = self._load_feature_extractor()
        return self._feature_extractor

    def _load_feature_extractor(self):
        """Load the feature extractor matching the model family; ``None`` on error."""
        try:
            if "swin" in self.model_name:
                from transformers import AutoFeatureExtractor
                return AutoFeatureExtractor.from_pretrained(self.model_name)
            if "convnext" in self.model_name:
                from transformers import ConvNextFeatureExtractor
                return ConvNextFeatureExtractor.from_pretrained(self.model_name)
            from transformers import ViTFeatureExtractor
            return ViTFeatureExtractor.from_pretrained(self.model_name)
        except Exception as e:
            logger.error(f"Feature extractor initialization error: {e}")
            return None

    def _model_class(self):
        """Return the transformers image-classification class for ``model_name``."""
        if "swin" in self.model_name:
            from transformers import SwinForImageClassification
            return SwinForImageClassification
        if "convnext" in self.model_name:
            from transformers import ConvNextForImageClassification
            return ConvNextForImageClassification
        from transformers import ViTForImageClassification
        return ViTForImageClassification

    def _load_model(self):
        """Load the pretrained backbone with a new head sized to ``defect_classes``.

        Returns:
            The model in eval mode on ``self.device``, or ``None`` (logged) on error.
        """
        try:
            num_labels = len(self.defect_classes)
            model = self._model_class().from_pretrained(
                self.model_name,
                num_labels=num_labels,
                ignore_mismatched_sizes=True,
            )
            # Reinitialize the classifier layer
            with torch.no_grad():
                if hasattr(model, 'classifier'):
                    model.classifier = torch.nn.Linear(model.classifier.in_features, num_labels)
                elif hasattr(model, 'head'):
                    model.head = torch.nn.Linear(model.head.in_features, num_labels)
            # Move to the device only *after* swapping the head: the new Linear is
            # created on the CPU and would otherwise stay behind on any other device.
            model = model.to(self.device)
            # A freshly built submodule defaults to training mode; set inference mode explicitly.
            model.eval()
            return model
        except Exception as e:
            logger.error(f"Model initialization error: {e}")
            return None

    def preprocess_image(self, image_bytes):
        """Preprocess image for model input (see ``_cached_preprocess_image``)."""
        return _cached_preprocess_image(image_bytes, self.model_name)

    def analyze_image(self, image, confidence_threshold=0.3):
        """Analyze image for defects.

        Args:
            image: Image accepted by the feature extractor (e.g. a PIL image).
            confidence_threshold: Classes with probability strictly above this
                value are reported.

        Returns:
            ``{defect_class: probability}`` for every class above the threshold;
            if none qualify, only the single most probable class. ``None`` if the
            model/feature extractor is unavailable or inference fails (logged).
        """
        try:
            if self.model is None:
                raise ValueError("Model not properly initialized")
            if self.feature_extractor is None:
                raise ValueError("Feature extractor not properly initialized")
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Softmax over the class axis of the first (only) image in the batch.
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0].tolist()
            results = {
                label: p
                for label, p in zip(self.defect_classes, probs)
                if p > confidence_threshold
            }
            if not results:
                best = max(range(len(probs)), key=probs.__getitem__)
                results = {self.defect_classes[best]: probs[best]}
            return results
        except Exception as e:
            logger.error(f"Analysis error: {str(e)}")
            return None
@st.cache_data
def _cached_preprocess_image(image_bytes, model_name):
    """Open an image, convert it to RGB and resize it to the model's input size.

    Args:
        image_bytes: Despite the name, a path or binary file-like object accepted
            by ``PIL.Image.open`` (wrap raw ``bytes`` in ``io.BytesIO`` first).
        model_name: Model id; names containing ``"convnext"`` get 384x384,
            everything else 224x224.

    Returns:
        A new RGB ``PIL.Image.Image``, or ``None`` if the image could not be read
        (the error is logged).
    """
    try:
        # The stream may already have been consumed; rewind so PIL reads the header.
        if hasattr(image_bytes, "seek"):
            image_bytes.seek(0)
        with Image.open(image_bytes) as source:
            image = source if source.mode == 'RGB' else source.convert('RGB')
            # Adjust size based on model requirements
            size = (384, 384) if "convnext" in model_name else (224, 224)
            # resize() returns a new, fully loaded image, so it remains valid
            # after the with-block releases the source image.
            return image.resize(size, Image.Resampling.LANCZOS)
    except Exception as e:
        logger.error(f"Image preprocessing error: {e}")
        return None
@st.cache_data
def _cached_groq_completion(query: str, context: str) -> str:
    """Ask the Groq LLM to answer ``query`` from ``context``; memoized per (query, context).

    Raises on any failure: Streamlit's cache does not store exceptions, so
    transient errors are retried on the next call instead of being replayed.
    """
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    prompt = (
        "Based on the following context about construction defects, answer the question.\n"
        f"Context: {context}\n"
        f"Question: {query}\n"
        "Provide a detailed answer based on the given context."
    )
    response = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a construction defect analysis expert."
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        model="llama-3.3-70b-versatile",
        temperature=0.7,
    )
    content = response.choices[0].message.content
    if not content:
        raise RuntimeError("empty completion returned by Groq")
    return content


def get_groq_response(query: str, context: str) -> str:
    """Get response from Groq LLM with caching.

    Successful answers are cached; error results are not, so a missing API key
    or a transient API failure does not stick to the query.

    Returns:
        The model's answer, or a message starting with ``"Error"`` when the
        ``GROQ_API_KEY`` environment variable is unset or the API call fails.
    """
    if not os.getenv("GROQ_API_KEY"):
        return "Error: Groq API key not configured"
    try:
        return _cached_groq_completion(query, context)
    except Exception as e:
        logger.error(f"Groq API error: {e}", exc_info=True)
        return f"Error: Unable to get response from AI model. Exception: {str(e)}"
def _plot_defect_probabilities(results):
    """Render a horizontal bar chart of ``{defect: probability}`` on the page."""
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.barh(list(results.keys()), list(results.values()))
        ax.set_xlim(0, 1)
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # Streamlit rerun would leak one figure.
        plt.close(fig)


def _render_image_analysis():
    """Upload widget, preprocessing, defect classification and results panel."""
    st.subheader("Image Analysis")
    uploaded_file = st.file_uploader(
        "Upload a construction image for analysis",
        type=["jpg", "jpeg", "png"],
        key="image_uploader",  # Add key for proper state management
    )
    if uploaded_file is None:
        return
    try:
        # Create a placeholder for the image
        image_placeholder = st.empty()
        # Process image with progress indicator
        with st.spinner('Processing image...'):
            processed_image = st.session_state.analyzer.preprocess_image(uploaded_file)
        if processed_image is None:
            st.error("Failed to process image. Please try another one.")
            return
        image_placeholder.image(processed_image, caption='Uploaded Image', use_container_width=True)
        # Analyze image with progress bar
        progress_bar = st.progress(0)
        with st.spinner('Analyzing defects...'):
            results = st.session_state.analyzer.analyze_image(processed_image)
            progress_bar.progress(100)
        if not results:
            st.warning("No defects detected or analysis failed. Please try another image.")
            return
        st.success('Analysis complete!')
        st.subheader("Detected Defects")
        _plot_defect_probabilities(results)
        most_likely_defect = max(results, key=results.get)
        st.info(f"Most likely defect: {most_likely_defect}")
    except Exception as e:
        st.error(f"Error processing image: {str(e)}")
        logger.error(f"Process error: {e}")


def _render_question_answering():
    """Question box: retrieve knowledge-base context, then ask the LLM."""
    st.subheader("Ask About Defects")
    user_query = st.text_input(
        "Ask a question about the defects or repairs:",
        help="Example: What are the repair methods for spalling?"
    )
    # Ignore empty or whitespace-only input instead of running retrieval on it.
    if not user_query or not user_query.strip():
        return
    with st.spinner('Getting answer...'):
        # Get context from RAG system
        context = st.session_state.rag_system.get_relevant_context(user_query)
        # Get response from Groq (only when there is context to ground it)
        response = get_groq_response(user_query, context) if context else None
    if not context:
        st.error("Could not find relevant information. Please try rephrasing your question.")
        return
    if response.startswith("Error"):
        st.error(response)
    else:
        st.write("Answer:")
        st.markdown(response)
    with st.expander("View retrieved information"):
        st.text(context)


def _render_sidebar():
    """About text, Groq API key status and cache controls."""
    st.header("About")
    st.write(
        "\n"
        "This tool helps analyze construction defects in images and provides\n"
        "information about repair methods and best practices.\n"
        "Features:\n"
        "- Image analysis for defect detection\n"
        "- Information lookup for repair methods\n"
        "- Expert AI responses to your questions\n"
    )
    # Display API status
    if os.getenv("GROQ_API_KEY"):
        st.success("Groq API: Connected")
    else:
        st.error("Groq API: Not configured")
    # Add settings section
    st.subheader("Settings")
    if st.button("Clear Cache"):
        st.cache_data.clear()
        st.success("Cache cleared!")


def main():
    """Build the page: image analysis (left), Q&A (right) and the sidebar."""
    # set_page_config must be the first Streamlit call of the run.
    st.set_page_config(
        page_title="Smart Construction Defect Analyzer",
        page_icon="🏗️",
        layout="wide"
    )
    st.title("🏗️ Smart Construction Defect Analyzer")
    # Initialize systems in session state if not present (reused across reruns).
    if 'analyzer' not in st.session_state:
        st.session_state.analyzer = ImageAnalyzer()
    if 'rag_system' not in st.session_state:
        st.session_state.rag_system = RAGSystem()
    col1, col2 = st.columns([1, 1])
    with col1:
        _render_image_analysis()
    with col2:
        _render_question_answering()
    with st.sidebar:
        _render_sidebar()
# Script entry point: build and render the app via main().
if __name__ == "__main__":
    main()