# CohortBot / app.py
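# Overview (summary of what this file does): a Gradio app that chats through the
# HuggingFaceH4/zephyr-7b-beta model via the Hugging Face Inference API and, when a
# chat export is pasted into the "Chat History for Analysis" box, parses researcher
# introductions into a pandas DataFrame and plots community statistics with Plotly.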
import gradio as gr
import pandas as pd
import re
from huggingface_hub import InferenceClient
import plotly.express as px
from collections import Counter
# Initialize Hugging Face client
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def parse_message(message):
    """Extract researcher information from a single chat message using regex."""
    info = {}

    # Extract timestamp and phone number, e.g. "[<timestamp>] <phone>: ..."
    timestamp_match = re.search(r'\[(.*?)\]', message)
    phone_match = re.search(r'\] (.*?):', message)
    if timestamp_match and phone_match:
        info['timestamp'] = timestamp_match.group(1)
        info['phone'] = phone_match.group(1)

    # Extract the message body (the text after the "sender:" prefix).
    # Slicing at the end of phone_match avoids splitting on colons inside the
    # timestamp; fall back to the first colon, or the whole line if there is none.
    if phone_match:
        content = message[phone_match.end():].strip()
    else:
        parts = message.split(':', 1)
        content = parts[1].strip() if len(parts) > 1 else message.strip()

    # Extract name (text up to the first bullet or dash)
    name_match = re.match(r'^([^•\n-]+)', content)
    if name_match:
        info['name'] = name_match.group(1).strip()

    # Extract affiliation
    affiliation_match = re.search(r'[Aa]ffiliation:?\s*([^•\n]+)', content)
    if affiliation_match:
        info['affiliation'] = affiliation_match.group(1).strip()

    # Extract research field/interests
    field_match = re.search(
        r'([Ff]ield of [Ii]nterest|[Dd]omaine de recherche|[Rr]esearch area|[Aa]reas of interest):?\s*([^•\n]+)',
        content,
    )
    if field_match:
        info['research_field'] = field_match.group(2).strip()

    # Extract thesis topic
    thesis_match = re.search(r'[Tt]hesis:?\s*([^•\n]+)', content)
    if thesis_match:
        info['thesis_topic'] = thesis_match.group(1).strip()

    return info
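# Illustrative example of a line parse_message() can handle (hypothetical data,
# not taken from any real chat export):
#   "[25/12/2024 10:15] +212600000000: Jane Doe • Affiliation: UM6P • Field of interest: NLP, ML • Thesis: Arabic summarization"
# would yield name "Jane Doe", affiliation "UM6P", research_field "NLP, ML",
# and thesis_topic "Arabic summarization".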
def create_researcher_df(chat_history):
    """Convert raw chat text (one message per line) into a structured DataFrame."""
    researchers = []
    messages = chat_history.split('\n')
    for message in messages:
        if message.strip():
            info = parse_message(message)
            if info:
                researchers.append(info)
    df = pd.DataFrame(researchers)
    return df
def analyze_research_fields(df):
    """Count research fields across researchers (comma-separated, case-insensitive)."""
    if 'research_field' not in df.columns:
        return pd.Series(dtype=int)
    fields = df['research_field'].dropna()
    # Split comma-separated fields and flatten into a single list
    all_fields = [field.strip().lower() for fields_list in fields for field in fields_list.split(',')]
    return pd.Series(Counter(all_fields))
def create_visualizations(df):
    """Create visualizations from the researcher data."""
    figures = []

    # 1. Affiliation distribution
    if 'affiliation' in df.columns and not df['affiliation'].empty:
        affiliation_counts = df['affiliation'].value_counts()
        fig_affiliation = px.pie(
            values=affiliation_counts.values,
            names=affiliation_counts.index,
            title='Distribution of Researchers by Affiliation'
        )
        figures.append(fig_affiliation)

    # 2. Research fields analysis
    field_counts = analyze_research_fields(df)
    if not field_counts.empty:
        fig_fields = px.bar(
            x=field_counts.index,
            y=field_counts.values,
            title='Popular Research Fields',
            labels={'x': 'Field', 'y': 'Count'}
        )
        figures.append(fig_fields)

    # The UI exposes a single gr.Plot output, so only the first figure is returned
    return figures[0] if figures else None
def process_message(
    message,
    chat_history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    chat_history_text
):
    """Process a message and return the chat response along with community analysis."""
    try:
        # Analyze the pasted chat history, if provided
        if chat_history_text:
            df = create_researcher_df(chat_history_text)

            # Generate analysis summary
            summary = f"Analysis of {len(df)} researchers:\n"
            if 'affiliation' in df.columns:
                summary += f"- Institutions represented: {df['affiliation'].nunique()}\n"
            field_counts = analyze_research_fields(df)
            if not field_counts.empty:
                top_fields = field_counts.nlargest(3)
                summary += "- Top research fields:\n"
                for field, count in top_fields.items():
                    summary += f"  • {field}: {count} researchers\n"

            # Create visualization
            fig = create_visualizations(df)

            # Append the analysis to the user's message so the model can comment on it
            message += f"\n\nCommunity Analysis:\n{summary}"
        else:
            fig = None

        # Generate a response using the LLM
        messages = [{"role": "system", "content": system_message}]
        for user_msg, bot_msg in chat_history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})

        response = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        bot_message = response.choices[0].message.content

        chat_history.append((message, bot_message))
        return chat_history, fig, chat_history
    except Exception as e:
        error_message = f"Error: {str(e)}"
        chat_history.append((message, error_message))
        return chat_history, None, chat_history
with gr.Blocks(title="CohortBot") as demo:
    chatbot = gr.Chatbot(label="Chat History")
    msg = gr.Textbox(label="Message", placeholder="Type your message here...")
    system_msg = gr.Textbox(value="You are a friendly Research Community Chatbot.", label="System message")
    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
    temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
    chat_history_text = gr.Textbox(label="Chat History for Analysis", lines=10)
    plot = gr.Plot(label="Community Analysis")

    msg.submit(
        process_message,
        [msg, chatbot, system_msg, max_tokens, temperature, top_p, chat_history_text],
        [chatbot, plot, chatbot]
    )
if __name__ == "__main__":
    demo.launch()
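# To try this locally (assuming the imports above are installed, e.g.
# `pip install gradio pandas plotly huggingface_hub`), run `python app.py` and open
# the printed local URL. Requests go to the hosted zephyr-7b-beta endpoint through
# the Inference API, so a Hugging Face token may be required depending on your setup.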