# HengJay's picture
# Update ITRI logo in title.
# 2b8f833
# raw
# history blame
# 11.1 kB
import streamlit as st
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Configure Streamlit before any other Streamlit command runs.
# XSRF protection is switched off for this app.
# NOTE(review): this weakens request-forgery protection; presumably it is a
# workaround for file-upload failures on the hosting platform — confirm it is
# still required.
st.config.set_option("server.enableXsrfProtection", False)
# Set the page title and the page icon.
st.set_page_config(page_title="Dysphagia Analysis", page_icon="👅")
# Quadrant chart for the EMG muscle coordination analysis
def _quadrant_corner(std_ratio, delta_t, high, low):
    """Return the lower-left corner of a side's marker square, or None.

    The square is placed on the right half of the chart when
    ``std_ratio > 3`` and on the upper half when ``delta_t > 0``; *high* and
    *low* are the corner coordinates used for the positive and negative half
    of each axis. ``None`` is returned when either value is NaN, because no
    quadrant can be chosen for it.
    """
    # NaN compares False against every threshold, so no quadrant applies.
    if np.isnan(std_ratio) or np.isnan(delta_t):
        return None
    x = high if std_ratio > 3 else low
    y = high if delta_t > 0 else low
    return (x, y)


def _add_quadrant_label(ax, x, y, text, va):
    """Draw one boxed quadrant caption centred horizontally on (x, y)."""
    box_style = dict(
        boxstyle='round,pad=0.5',
        fc='#1f77b4',  # box fill colour (blue)
        ec='#1f77b4',  # box edge colour (blue)
        alpha=0.7,     # box opacity
        lw=1.5,        # box edge width
    )
    ax.text(
        x, y, text,
        horizontalalignment='center',
        verticalalignment=va,
        bbox=box_style,
        color='white',
        fontweight='semibold',
        fontsize=9,
    )


def emg_plot(event_index, event_plot_name, left_std_ratio, left_delta_t, right_std_ratio, right_delta_t):
    """Render the exertion/coordination quadrant chart for one event.

    A blue square (left side) and a green square (right side) are drawn in
    the quadrant selected by that side's values: the right half when the std
    ratio is greater than 3, the upper half when delta T is greater than 0.
    The figure is sent to the Streamlit page and then closed.

    Parameters:
        event_index: Event number shown in the chart title.
        event_plot_name: Event name shown in the chart title.
        left_std_ratio (float): Std ratio of the left side.
        left_delta_t (float): Delta T of the left side.
        right_std_ratio (float): Std ratio of the right side.
        right_delta_t (float): Delta T of the right side.

    Returns:
        None. A side whose values are NaN gets no square.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # (std ratio, delta T, positive-half corner, negative-half corner, colour).
        # The right-side square is offset by 0.5 so both stay visible when
        # they land in the same quadrant.
        sides = (
            (left_std_ratio, left_delta_t, 2, -8, 'blue'),
            (right_std_ratio, right_delta_t, 1.5, -8.5, 'green'),
        )
        for std_ratio, delta_t, high, low, colour in sides:
            corner = _quadrant_corner(std_ratio, delta_t, high, low)
            if corner is not None:
                ax.add_patch(plt.Rectangle(corner, 6, 6, color=colour, alpha=0.5))

        # Dashed axes split the chart into four quadrants.
        ax.axhline(y=0, color='black', linestyle='--')
        ax.axvline(x=0, color='black', linestyle='--')

        # Quadrant captions
        _add_quadrant_label(ax, 4, 0.5, "Exertion + / Coordination -", 'bottom')
        _add_quadrant_label(ax, -4, 0.5, "Exertion - / Coordination -", 'bottom')
        _add_quadrant_label(ax, -4, -0.5, "Exertion - / Coordination +", 'top')
        _add_quadrant_label(ax, 4, -0.5, "Exertion + / Coordination +", 'top')

        # Title and axis labels
        ax.set_title(f'Muscle Coordination Analysis - {event_index}:{event_plot_name}', fontsize=14)
        ax.set_xlabel('Exertion (Std Ratio > 3)', fontsize=12, fontweight='semibold')
        ax.set_ylabel('Coordination (Delta T > 0)', fontsize=12, fontweight='semibold')

        # The axes are categorical, so ticks (and with them tick labels) are removed.
        ax.set_xticks([])
        ax.set_yticks([])

        # Explicit proxy handles tie each legend entry to its colour instead
        # of relying on the order in which artists were added to the axes.
        legend_handles = [
            plt.Rectangle((0, 0), 1, 1, color='blue', alpha=0.5, label='Left Swallowing Muscle'),
            plt.Rectangle((0, 0), 1, 1, color='green', alpha=0.5, label='Right Swallowing Muscle'),
        ]
        ax.legend(handles=legend_handles, loc='upper left', fontsize=10)

        # Fixed limits keep the quadrant layout identical for every event.
        ax.set_xlim(-10, 10)
        ax.set_ylim(-10, 10)

        # Render this specific figure rather than pyplot's global current figure.
        st.pyplot(fig)
    finally:
        # One figure is created per event; close it so figures do not pile up.
        plt.close(fig)
def _moving_rms(signal, window_size):
    """Return the moving root-mean-square of *signal*.

    Each output value is the square root of the mean of the squared samples
    in the trailing window of *window_size* samples. The result keeps the
    index of *signal*; positions without a full window are NaN.
    """
    squared = pd.Series(signal) ** 2
    return np.sqrt(squared.rolling(window=window_size).mean())


def _load_data(emg_source, marker_source, success_message):
    """Read both CSVs into the session state and report the outcome.

    The session state is only updated when both files were read, so a failure
    on the second file never leaves a half-loaded data set behind.
    """
    try:
        emg_frame = pd.read_csv(emg_source, skiprows=[0, 1, 3, 4])
        marker_frame = pd.read_csv(marker_source)
    except Exception as e:  # UI boundary: show any read/parse failure to the user
        st.error(f"Error: {e}")
        return
    st.session_state.emg_data = emg_frame
    st.session_state.time_marker = marker_frame
    st.session_state.data_isready = True
    st.success(success_message)


def _analyze_events(emg_frame, marker_frame):
    """Compute std ratio and delta T for every marked event and plot each one.

    Marker rows are read in pairs: rows 0, 2, 4, ... give an event's start
    time and name, rows 1, 3, 5, ... give its end time.
    """
    # Index the EMG samples by the 'Channels' column.
    # NOTE(review): presumably this column holds the sample time in
    # microseconds, matching the marker times — confirm against the recorder.
    emg_data = emg_frame.set_index('Channels')

    # Each muscle signal is the difference of two recorded channels,
    # smoothed with a moving RMS over 25 samples.
    rms_window = 25
    left_lateral_rms = _moving_rms(emg_data['17'] - emg_data['18'], rms_window)
    left_medial_rms = _moving_rms(emg_data['19'] - emg_data['20'], rms_window)
    right_lateral_rms = _moving_rms(emg_data['23'] - emg_data['24'], rms_window)
    right_medial_rms = _moving_rms(emg_data['21'] - emg_data['22'], rms_window)

    # Positional slicing pairs the rows regardless of the frame's index labels.
    event_times = marker_frame['0-New_Task-recording_time(us)']
    event_start_times = event_times.iloc[0::2].to_numpy().astype(int)
    event_end_times = event_times.iloc[1::2].to_numpy().astype(int)
    event_names = marker_frame['description'].iloc[0::2].to_numpy()
    if len(event_start_times) != len(event_end_times):
        st.warning("The time marker file has an unpaired start marker; the last event is skipped.")

    # Baseline spread: std of the lateral RMS signals up to index 10000000
    # (the first 10 s if the index is in microseconds).
    left_baseline_std = left_lateral_rms.loc[:10000000].std()
    right_baseline_std = right_lateral_rms.loc[:10000000].std()

    # All four RMS signals share the EMG index, so one mask per event suffices.
    sample_times = emg_data.index
    events = zip(event_names, event_start_times, event_end_times)
    for event_number, (event_name, event_start_time, event_end_time) in enumerate(events, start=1):
        st.write(f"Event {event_number}: {event_name}")
        st.write(f"Start time: {float(event_start_time)/1000000: .3f} sec, End time: {float(event_end_time)/1000000: .3f} sec")

        in_event = (sample_times >= event_start_time) & (sample_times <= event_end_time)
        event_signal_LL = left_lateral_rms.loc[in_event]
        event_signal_LM = left_medial_rms.loc[in_event]
        event_signal_RL = right_lateral_rms.loc[in_event]
        event_signal_RM = right_medial_rms.loc[in_event]

        # Without any valid sample there is no peak to locate.
        windows = (event_signal_LL, event_signal_LM, event_signal_RL, event_signal_RM)
        if any(window.dropna().empty for window in windows):
            st.warning(f"Event {event_number}: no EMG samples inside the event window; skipped.")
            continue

        # Std ratio: spread during the event relative to the baseline spread.
        left_std_ratio = event_signal_LL.std() / left_baseline_std
        right_std_ratio = event_signal_RL.std() / right_baseline_std
        st.write(f"left std ratio: {left_std_ratio: .3f}, right std ratio: {right_std_ratio: .3f}")

        # Delta T: index of the medial peak minus index of the lateral peak.
        LL_max_index = event_signal_LL.idxmax()
        LM_max_index = event_signal_LM.idxmax()
        left_delta_t = LM_max_index - LL_max_index
        st.write(f"LM_max_index: {float(LM_max_index)/1000000: .3f}, LL_max_index: {float(LL_max_index)/1000000: .3f}, left delta t: {float(left_delta_t)/1000000: .3f}")
        RL_max_index = event_signal_RL.idxmax()
        RM_max_index = event_signal_RM.idxmax()
        right_delta_t = RM_max_index - RL_max_index
        st.write(f"RM_max_index: {float(RM_max_index)/1000000: .3f}, RL_max_index: {float(RL_max_index)/1000000: .3f}, right delta t: {float(right_delta_t)/1000000: .3f}")

        # One quadrant chart per event
        emg_plot(event_number, event_name, left_std_ratio, left_delta_t, right_std_ratio, right_delta_t)


def main():
    """Build the Streamlit page: load the data, show it, and run the analysis.

    Data comes either from two uploaded CSV files or from the bundled test
    files. Once data is ready, the 'Start Analysis' button switches the
    analysis on; it then runs on every rerun of the script.
    """
    st.image("logo/itri_logo.jpg", width=600)
    st.title('👅Dysphagia Analysis - by ITRI BDL')

    # Initialize session state variables
    session_defaults = {
        'emg_data': None,
        'time_marker': None,
        'analysis_started': False,
        'data_isready': False,
    }
    for key, default in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

    # File uploaders
    uploaded_file1 = st.file_uploader("Choose the EMG_data CSV file", type="csv")
    uploaded_file2 = st.file_uploader("Choose the time_marker CSV file", type="csv")

    # Load data when both files are uploaded
    if uploaded_file1 is not None and uploaded_file2 is not None:
        # Rewind first: the files are re-read on every rerun, and an already
        # consumed buffer would otherwise parse as empty.
        uploaded_file1.seek(0)
        uploaded_file2.seek(0)
        _load_data(uploaded_file1, uploaded_file2, "Data loaded successfully!")

    # Load test data button
    if st.button('Load Test Data', type="primary"):
        _load_data(
            'test-new/0-New_Task-recording-0.csv',
            'test-new/time_marker.csv',
            "Test data loaded successfully!",
        )

    # Display loaded data
    if st.session_state.emg_data is not None:
        st.subheader("EMG Data")
        st.dataframe(st.session_state.emg_data)
    if st.session_state.time_marker is not None:
        st.subheader("Time Marker")
        st.dataframe(st.session_state.time_marker)

    # Nothing to analyse until a data set has been loaded.
    if not st.session_state.data_isready:
        return

    st.subheader("Muscle Coordination Analysis")
    if st.button('Start Analysis', type="primary"):
        st.session_state.analysis_started = True

    # Perform analysis if started
    if st.session_state.analysis_started:
        st.write('Analysis in progress...')
        _analyze_events(st.session_state.emg_data, st.session_state.time_marker)
        st.write('Analysis completed!')
# Run the app only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()