Spaces:
Running
Running
File size: 11,093 Bytes
b99e090 4699f43 b99e090 4699f43 b99e090 2b8f833 4699f43 b99e090 4699f43 b99e090 4699f43 b99e090 4699f43 abac731 b99e090 4699f43 abac731 b99e090 4699f43 b99e090 4699f43 abac731 4699f43 b99e090 4699f43 b99e090 2b8f833 b99e090 4699f43 2b8f833 4699f43 b99e090 4699f43 b99e090 4699f43 b99e090 4699f43 abac731 4699f43 b99e090 4699f43 abac731 4699f43 abac731 b99e090 4699f43 abac731 4699f43 abac731 4699f43 abac731 4699f43 abac731 4699f43 b99e090 4699f43 b99e090 |
|
import streamlit as st
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Streamlit server/page configuration.
# NOTE(review): this disables Streamlit's XSRF protection for the whole server.
# Presumably done so file uploads work when the app is embedded/proxied
# (e.g. a hosted Space) — confirm it is still required, since it weakens
# protection against cross-site request forgery on the upload endpoint.
st.config.set_option("server.enableXsrfProtection", False)
# Browser-tab title and icon. Streamlit's docs require set_page_config to run
# before other st.* rendering calls, so keep it at the top of the script.
st.set_page_config(page_title="Dysphagia Analysis", page_icon="👅")
# Function to plot the EMG signal coordination analysis
def emg_plot(event_index, event_plot_name, left_std_ratio, left_delta_t,
             right_std_ratio, right_delta_t, std_ratio_threshold=3):
    """Render a quadrant chart classifying left/right muscle activity for one event.

    The x axis splits on exertion (std ratio above ``std_ratio_threshold``)
    and the y axis on coordination (delta t above zero). Each side is drawn
    as a translucent 6x6 square in the quadrant its values fall into: blue
    for the left muscle, green (shifted by 0.5 so both stay visible when they
    share a quadrant) for the right muscle. The figure is rendered with
    ``st.pyplot`` and then closed.

    Parameters:
        event_index (int): Event number shown in the title.
        event_plot_name (str): Event description shown in the title.
        left_std_ratio (float): Left event std divided by the left baseline std.
        left_delta_t (float): Left medial-peak time minus lateral-peak time
            (same units as the signal index; microseconds in ``main``).
        right_std_ratio (float): Right-side equivalent of ``left_std_ratio``.
        right_delta_t (float): Right-side equivalent of ``left_delta_t``.
        std_ratio_threshold (float): Exertion cut-off on the std ratio (default 3).

    A side whose std ratio or delta t is NaN is not drawn; its legend entry
    is still shown so the legend colours stay correct.
    """

    def quadrant_origin(std_ratio, delta_t, shift):
        """Return the lower-left corner of the quadrant square, or None if undecidable."""
        # NaN compares False both ways, so no quadrant can be chosen for it.
        if pd.isna(std_ratio) or pd.isna(delta_t):
            return None
        x = 2 if std_ratio > std_ratio_threshold else -8
        y = 2 if delta_t > 0 else -8
        return (x - shift, y - shift)

    def add_styled_text(x, y, text, va='bottom'):
        """Draw a centred white label inside a rounded blue box."""
        bbox_props = dict(
            boxstyle='round,pad=0.5',
            fc='#1f77b4',  # box fill colour (blue)
            ec='#1f77b4',  # box edge colour (same blue)
            alpha=0.7,     # box transparency
            lw=1.5,        # box edge width
        )
        ax.text(x, y, text,
                horizontalalignment='center',
                verticalalignment=va,
                bbox=bbox_props,
                color='white',
                fontweight='semibold',
                fontsize=9)

    fig, ax = plt.subplots(figsize=(8, 8))

    sides = (
        ('Left Swallowing Muscle', 'blue', left_std_ratio, left_delta_t, 0.0),
        ('Right Swallowing Muscle', 'green', right_std_ratio, right_delta_t, 0.5),
    )
    legend_handles = []
    for label, color, std_ratio, delta_t, shift in sides:
        origin = quadrant_origin(std_ratio, delta_t, shift)
        if origin is not None:
            ax.add_patch(plt.Rectangle(origin, 6, 6, color=color, alpha=0.5))
        # Proxy artist: ties each legend label to its colour even when the
        # side was not drawn (a label-only legend would grab other artists).
        legend_handles.append(
            plt.Rectangle((0, 0), 1, 1, color=color, alpha=0.5, label=label))

    # Dashed axes splitting the chart into four quadrants.
    ax.axhline(y=0, color='black', linestyle='--')
    ax.axvline(x=0, color='black', linestyle='--')

    # Quadrant labels.
    # NOTE(review): the upper half (delta t > 0) is labelled "Coordination -"
    # while the y-axis reads "Coordination (Delta T > 0)"; confirm this sign
    # convention is intended.
    add_styled_text(4, 0.5, "Exertion + / Coordination -", 'bottom')
    add_styled_text(-4, 0.5, "Exertion - / Coordination -", 'bottom')
    add_styled_text(-4, -0.5, "Exertion - / Coordination +", 'top')
    add_styled_text(4, -0.5, "Exertion + / Coordination +", 'top')

    ax.set_title(f'Muscle Coordination Analysis - {event_index}:{event_plot_name}', fontsize=14)
    ax.set_xlabel(f'Exertion (Std Ratio > {std_ratio_threshold:g})', fontsize=12, fontweight='semibold')
    ax.set_ylabel('Coordination (Delta T > 0)', fontsize=12, fontweight='semibold')

    # The chart is categorical, so hide numeric ticks (this also removes tick labels).
    ax.set_xticks([])
    ax.set_yticks([])

    ax.legend(handles=legend_handles, loc='upper left', fontsize=10)
    ax.set_xlim(-10, 10)
    ax.set_ylim(-10, 10)

    st.pyplot(fig)
    # Streamlit reruns the script on every interaction and pyplot keeps every
    # figure alive until closed, so release it once it has been rendered.
    plt.close(fig)
def moving_rms(signal, window_size):
    """Return the moving root-mean-square of ``signal``.

    Each output sample is ``sqrt(mean(x**2))`` over the trailing
    ``window_size`` samples. The first ``window_size - 1`` outputs are NaN
    because the window is not yet full.

    Parameters:
        signal (pandas.Series or array-like): Samples to smooth.
        window_size (int): Number of samples in the rolling window.

    Returns:
        pandas.Series: RMS envelope, indexed like ``signal`` when it is a Series.
    """
    series = pd.Series(signal)
    # Square *before* averaging: sqrt(mean(x) ** 2) is only |moving mean|,
    # which stays near zero for a signal oscillating around zero.
    mean_square = (series ** 2).rolling(window=window_size).mean()
    # Guard against tiny negative values from floating-point round-off.
    return np.sqrt(mean_square.clip(lower=0))


def _init_session_state():
    """Create the session-state keys used by the app if they are missing."""
    defaults = {
        'emg_data': None,
        'time_marker': None,
        'analysis_started': False,
        'data_isready': False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def _load_data():
    """Render the upload widgets and test-data button; store loaded frames in session state."""
    uploaded_file1 = st.file_uploader("Choose the EMG_data CSV file", type="csv")
    uploaded_file2 = st.file_uploader("Choose the time_marker CSV file", type="csv")
    if uploaded_file1 is not None and uploaded_file2 is not None:
        try:
            # Rows 0, 1, 3, 4 are skipped, so the file's third line is the header.
            st.session_state.emg_data = pd.read_csv(uploaded_file1, skiprows=[0, 1, 3, 4])
            st.session_state.time_marker = pd.read_csv(uploaded_file2)
        except Exception as e:  # UI boundary: report any read/parse error to the user
            st.error(f"Error: {e}")
        else:
            st.success("Data loaded successfully!")
            st.session_state.data_isready = True

    if st.button('Load Test Data', type="primary"):
        try:
            st.session_state.emg_data = pd.read_csv('test-new/0-New_Task-recording-0.csv', skiprows=[0, 1, 3, 4])
            st.session_state.time_marker = pd.read_csv('test-new/time_marker.csv')
        except Exception as e:  # UI boundary: e.g. bundled test files missing
            st.error(f"Error: {e}")
        else:
            st.success("Test data loaded successfully!")
            st.session_state.data_isready = True


def _show_loaded_data():
    """Display whichever of the EMG data / time marker frames have been loaded."""
    if st.session_state.emg_data is not None:
        st.subheader("EMG Data")
        st.dataframe(st.session_state.emg_data)
    if st.session_state.time_marker is not None:
        st.subheader("Time Marker")
        st.dataframe(st.session_state.time_marker)


def _event_window(series, start, end):
    """Return the samples of ``series`` whose index lies in the closed range [start, end]."""
    mask = (series.index >= start) & (series.index <= end)
    return series.iloc[mask]


def _run_analysis(window_size=25, baseline_end_us=10_000_000):
    """Run the per-event muscle coordination analysis on session-state data and render it.

    Parameters:
        window_size (int): Moving-RMS window, in samples.
        baseline_end_us (int): End of the baseline period on the 'Channels'
            index; the event times are compared against the same index and
            printed as seconds after dividing by 1e6, i.e. microseconds.
    """
    st.write('Analysis in progress...')
    emg_data = st.session_state.emg_data.set_index('Channels')

    # Each muscle signal is the difference between two electrode channel columns.
    channel_pairs = {
        'LL': ('17', '18'),  # left lateral
        'LM': ('19', '20'),  # left medial
        'RL': ('23', '24'),  # right lateral
        'RM': ('21', '22'),  # right medial
    }
    rms = {
        site: moving_rms(emg_data[pos] - emg_data[neg], window_size)
        for site, (pos, neg) in channel_pairs.items()
    }

    time_col = '0-New_Task-recording_time(us)'
    time_marker = st.session_state.time_marker[[time_col, 'description', 'tag']]
    time_marker = time_marker.rename(columns={time_col: 'event_time'})
    # Markers come in pairs: even rows start an event, odd rows end it.
    event_start_times = time_marker['event_time'].iloc[0::2].to_numpy().astype(int)
    event_end_times = time_marker['event_time'].iloc[1::2].to_numpy().astype(int)
    event_names = time_marker['description'].iloc[0::2].to_numpy()
    if len(event_start_times) > len(event_end_times):
        st.warning(f"Last event '{event_names[-1]}' has no end marker and is skipped.")

    # Baseline variability of the lateral RMS envelopes.
    left_baseline_std = rms['LL'].loc[:baseline_end_us].std()
    right_baseline_std = rms['RL'].loc[:baseline_end_us].std()
    if not (left_baseline_std > 0 and right_baseline_std > 0):
        st.error(f"Baseline (first {baseline_end_us / 1e6:g} s) std is zero or undefined; "
                 "cannot compute std ratios.")
        return

    events = zip(event_names, event_start_times, event_end_times)
    for event_no, (event_name, event_start_time, event_end_time) in enumerate(events, start=1):
        st.write(f"Event {event_no}: {event_name}")
        st.write(f"Start time: {float(event_start_time)/1000000: .3f} sec, End time: {float(event_end_time)/1000000: .3f} sec")

        window = {site: _event_window(series, event_start_time, event_end_time)
                  for site, series in rms.items()}
        # idxmax cannot locate a peak in an empty or all-NaN window.
        if any(samples.isna().all() for samples in window.values()):
            st.warning(f"Event {event_no}: no RMS samples inside the event window; skipped.")
            continue

        left_std_ratio = window['LL'].std() / left_baseline_std
        right_std_ratio = window['RL'].std() / right_baseline_std
        st.write(f"left std ratio: {left_std_ratio: .3f}, right std ratio: {right_std_ratio: .3f}")

        # Delta t: time of the medial RMS peak minus time of the lateral RMS peak.
        ll_peak = window['LL'].idxmax()
        lm_peak = window['LM'].idxmax()
        left_delta_t = lm_peak - ll_peak
        st.write(f"LM_max_index: {float(lm_peak)/1000000: .3f}, LL_max_index: {float(ll_peak)/1000000: .3f}, left delta t: {float(left_delta_t)/1000000: .3f}")
        rl_peak = window['RL'].idxmax()
        rm_peak = window['RM'].idxmax()
        right_delta_t = rm_peak - rl_peak
        st.write(f"RM_max_index: {float(rm_peak)/1000000: .3f}, RL_max_index: {float(rl_peak)/1000000: .3f}, right delta t: {float(right_delta_t)/1000000: .3f}")

        emg_plot(event_no, event_name, left_std_ratio, left_delta_t, right_std_ratio, right_delta_t)
    st.write('Analysis completed!')


def main():
    """Streamlit entry point: load EMG + time-marker data, then analyse each event on request."""
    st.image("logo/itri_logo.jpg", width=600)
    st.title('👅Dysphagia Analysis - by ITRI BDL')
    _init_session_state()
    _load_data()
    _show_loaded_data()
    if st.session_state.data_isready:
        st.subheader("Muscle Coordination Analysis")
        if st.button('Start Analysis', type="primary"):
            st.session_state.analysis_started = True
    # Once started, the analysis re-runs on every Streamlit rerun.
    if st.session_state.analysis_started:
        _run_analysis()
# Run the Streamlit app when executed as a script (e.g. `streamlit run`).
if __name__ == '__main__':
    main()