# CohortBot — Gradio app for a Hugging Face Space.
# Standard library
import re
from collections import Counter

# Third-party
import gradio as gr
import pandas as pd
import plotly.express as px
from huggingface_hub import InferenceClient

# Hugging Face Inference API client used for all chat completions below.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
def parse_message(message):
    """Extract researcher info from a single chat line.

    Expected line shape: "[timestamp] phone: Name • Affiliation: ... • ...".

    Args:
        message: One raw chat-export line.

    Returns:
        Dict with any of the keys 'timestamp', 'phone', 'name',
        'affiliation', 'research_field', 'thesis_topic' that could be
        extracted; empty dict when the "[...] phone:" header is missing.
    """
    info = {}
    timestamp_match = re.search(r'\[(.*?)\]', message)
    phone_match = re.search(r'\] (.*?):', message)
    if timestamp_match and phone_match:
        info['timestamp'] = timestamp_match.group(1)
        info['phone'] = phone_match.group(1)
        # Slice just past the matched "phone:" header instead of splitting on
        # the first ':' — timestamps such as "[10:30, 1/1/24]" contain colons,
        # so split(':', 1) would cut inside the timestamp and corrupt content.
        content = message[phone_match.end():].strip()
        # Name: leading free text up to the first bullet/dash/newline.
        name_match = re.match(r'^([^•\n-]+)', content)
        if name_match:
            info['name'] = name_match.group(1).strip()
        # Affiliation: text after an "Affiliation" label, up to the next bullet.
        affiliation_match = re.search(r'[Aa]ffiliation:?\s*([^•\n]+)', content)
        if affiliation_match:
            info['affiliation'] = affiliation_match.group(1).strip()
        # Research field: several English/French label variants are accepted.
        field_match = re.search(
            r'([Ff]ield of [Ii]nterest|[Dd]omaine de recherche|'
            r'[Rr]esearch area|[Aa]reas of interest):?\s*([^•\n]+)',
            content,
        )
        if field_match:
            info['research_field'] = field_match.group(2).strip()
        # Thesis topic: text after a "Thesis" label.
        thesis_match = re.search(r'[Tt]hesis:?\s*([^•\n]+)', content)
        if thesis_match:
            info['thesis_topic'] = thesis_match.group(1).strip()
    return info
def create_researcher_df(chat_history):
    """Parse a raw chat dump into a DataFrame of researcher records.

    Args:
        chat_history: Newline-separated chat-export text.

    Returns:
        DataFrame with one row per line that parse_message could decode.
    """
    records = []
    for line in chat_history.split('\n'):
        # Skip blank lines and lines parse_message cannot decode.
        if not line.strip():
            continue
        parsed = parse_message(line)
        if parsed:
            records.append(parsed)
    return pd.DataFrame(records)
def analyze_research_fields(df):
    """Count how often each research field appears across researchers.

    'research_field' cells may hold comma-separated lists; every fragment
    is normalized (stripped, lower-cased) and counted separately.

    Args:
        df: DataFrame that may contain a 'research_field' column.

    Returns:
        Int-valued Series keyed by normalized field name; empty when the
        column is absent or holds no usable data.
    """
    if 'research_field' not in df.columns:
        # Explicit dtype avoids pandas' warning for object-dtype empty Series.
        return pd.Series(dtype=int)
    fields = df['research_field'].dropna()
    # Split comma-separated cells and flatten; drop empty fragments produced
    # by trailing or doubled commas so '' never becomes a counted "field".
    all_fields = [
        fragment.strip().lower()
        for cell in fields
        for fragment in cell.split(',')
        if fragment.strip()
    ]
    return pd.Series(Counter(all_fields), dtype=int)
def create_visualizations(df):
    """Build the single community-analysis figure shown in the UI.

    The affiliation pie chart takes priority; the research-field bar chart
    is the fallback (only one gr.Plot output exists downstream).

    Args:
        df: Researcher DataFrame from create_researcher_df.

    Returns:
        A plotly figure, or None when neither chart has data.
    """
    if 'affiliation' in df.columns and not df['affiliation'].empty:
        by_affiliation = df['affiliation'].value_counts()
        return px.pie(
            values=by_affiliation.values,
            names=by_affiliation.index,
            title='Distribution of Researchers by Affiliation'
        )

    field_counts = analyze_research_fields(df)
    if not field_counts.empty:
        return px.bar(
            x=field_counts.index,
            y=field_counts.values,
            title='Popular Research Fields',
            labels={'x': 'Field', 'y': 'Count'}
        )

    return None
def process_message(
    message,
    chat_history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    chat_history_text
):
    """Answer *message* with the LLM, optionally enriched with analysis.

    When a chat-history dump is supplied, a community summary is appended
    to the prompt and a figure is produced for the plot output.

    Returns:
        (chat_history, figure_or_None, chat_history) — history twice
        because the UI wires two outputs to the same Chatbot component.
    """
    try:
        fig = None
        if chat_history_text:
            df = create_researcher_df(chat_history_text)

            # Assemble a short textual community summary.
            parts = [f"Analysis of {len(df)} researchers:\n"]
            if 'affiliation' in df.columns:
                parts.append(f"- Institutions represented: {df['affiliation'].nunique()}\n")
            field_counts = analyze_research_fields(df)
            if not field_counts.empty:
                parts.append("- Top research fields:\n")
                for field, count in field_counts.nlargest(3).items():
                    parts.append(f" • {field}: {count} researchers\n")
            summary = "".join(parts)

            fig = create_visualizations(df)
            # Feed the analysis back into the user prompt for the LLM.
            message += f"\n\nCommunity Analysis:\n{summary}"

        # Build the OpenAI-style message list from prior (user, bot) turns.
        messages = [{"role": "system", "content": system_message}]
        for user_msg, bot_msg in chat_history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})

        response = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        bot_message = response.choices[0].message.content
        chat_history.append((message, bot_message))
        return chat_history, fig, chat_history
    except Exception as e:
        # Top-level boundary: surface the failure inside the chat itself.
        chat_history.append((message, f"Error: {str(e)}"))
        return chat_history, None, chat_history
# ---- Gradio UI (component order defines the on-page layout) -------------
with gr.Blocks(title="CohortBot") as demo:
    chatbot = gr.Chatbot(label="Chat History")
    msg = gr.Textbox(label="Message", placeholder="Type your message here...")
    system_msg = gr.Textbox(
        value="You are a friendly Research Community Chatbot.",
        label="System message",
    )
    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
    temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
    chat_history_text = gr.Textbox(label="Chat History for Analysis", lines=10)
    plot = gr.Plot(label="Community Analysis")

    # Submitting the textbox runs the whole pipeline; the Chatbot is listed
    # twice in outputs because process_message returns the history twice.
    msg.submit(
        process_message,
        inputs=[msg, chatbot, system_msg, max_tokens, temperature, top_p, chat_history_text],
        outputs=[chatbot, plot, chatbot],
    )

if __name__ == "__main__":
    demo.launch()