mgyigit commited on
Commit
6e059bb
·
verified ·
1 Parent(s): 2a5f723

Update src/vis_utils.py

Browse files
Files changed (1) hide show
  1. src/vis_utils.py +17 -1
src/vis_utils.py CHANGED
@@ -152,8 +152,17 @@ def plot_function_results(method_names, aspect, metric, function_path="/tmp/func
152
  }
153
 
154
  # Create clustermap
155
- g = sns.clustermap(df, annot=True, cmap="YlGnBu", row_cluster=False, col_cluster=False, figsize=(15, 15))
156
 
 
 
 
 
 
 
 
 
 
157
  title = f"{long_form_mapping[aspect.upper()]} Results for {metric.capitalize()}"
158
  g.fig.suptitle(title, x=0.5, y=1.02, fontsize=16, ha='center') # Center the title above the plot
159
 
@@ -180,6 +189,12 @@ def plot_family_results(method_names, dataset, family_path="/tmp/family_results.
180
  # Filter by method names and selected dataset columns
181
  df = df[df['Method'].isin(method_names)]
182
 
 
 
 
 
 
 
183
  # Filter columns based on the dataset and metrics
184
  value_vars = [col for col in df.columns if col.startswith(f"{dataset}_") and "_" in col]
185
 
@@ -244,6 +259,7 @@ def plot_affinity_results(method_names, metric, affinity_path="/tmp/affinity_res
244
 
245
  # Gather columns related to the specified metric and validate
246
  metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
 
247
  df = df[['Method'] + metric_columns].set_index('Method')
248
 
249
  df = df.fillna(0)
 
152
  }
153
 
154
  # Create clustermap
155
+ g = sns.clustermap(df, annot=True, cmap="YlGnBu", row_cluster=True, col_cluster=True, figsize=(15, 15))
156
 
157
+ for label in g.ax_heatmap.get_yticklabels():
158
+ method = label.get_text()
159
+ label.set_color(get_method_color(method))
160
+
161
+ # Apply color to column labels
162
+ for label in g.ax_heatmap.get_xticklabels():
163
+ method = label.get_text()
164
+ label.set_color(get_method_color(method))
165
+
166
  title = f"{long_form_mapping[aspect.upper()]} Results for {metric.capitalize()}"
167
  g.fig.suptitle(title, x=0.5, y=1.02, fontsize=16, ha='center') # Center the title above the plot
168
 
 
189
  # Filter by method names and selected dataset columns
190
  df = df[df['Method'].isin(method_names)]
191
 
192
+ mcc_columns = [col for col in df.columns if col.startswith(f"{dataset}_mcc_")]
193
+ df['Mean_MCC'] = df[mcc_columns].mean(axis=1)
194
+
195
+ # Sort the DataFrame by the mean MCC
196
+ df = df.sort_values(by='Mean_MCC', ascending=False)
197
+
198
  # Filter columns based on the dataset and metrics
199
  value_vars = [col for col in df.columns if col.startswith(f"{dataset}_") and "_" in col]
200
 
 
259
 
260
  # Gather columns related to the specified metric and validate
261
  metric_columns = [col for col in df.columns if col.startswith(f"{metric}_")]
262
+ df = df.sort_values(by=metric_columns, ascending=False)
263
  df = df[['Method'] + metric_columns].set_index('Method')
264
 
265
  df = df.fillna(0)