mgyigit commited on
Commit
81cc688
·
verified ·
1 Parent(s): 481c4ca

Update src/vis_utils.py

Browse files
Files changed (1) hide show
  1. src/vis_utils.py +24 -9
src/vis_utils.py CHANGED
@@ -119,9 +119,13 @@ def plot_similarity_results(methods_selected, x_metric, y_metric, similarity_pat
119
 
120
  return filename
121
 
122
- def plot_function_results(file_path, aspect, metric, method_names):
 
 
 
 
123
  # Load data
124
- df = pd.read_csv(file_path)
125
 
126
  # Filter for selected methods
127
  df = df[df['Method'].isin(method_names)]
@@ -131,24 +135,35 @@ def plot_function_results(file_path, aspect, metric, method_names):
131
  df = df[['Method'] + columns_to_plot]
132
  df.set_index('Method', inplace=True)
133
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  # Create clustermap
135
  g = sns.clustermap(df, annot=True, cmap="YlGnBu", row_cluster=False, col_cluster=False, figsize=(15, 15))
136
-
 
 
 
137
  # Get heatmap axis and customize labels
138
  ax = g.ax_heatmap
139
  ax.set_xlabel("")
140
  ax.set_ylabel("")
141
-
142
- # Apply color and caret adjustments to x-axis labels
143
- set_colors_and_marks_for_representation_groups(ax)
144
 
145
  # Save the plot as an image
146
- save_path = "./plot_images" # Ensure this folder exists or adjust the path
147
- os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
148
  filename = os.path.join(save_path, f"{aspect}_{metric}_heatmap.png")
149
  plt.savefig(filename, dpi=400, bbox_inches='tight')
150
  plt.close() # Close the plot to free memory
151
-
152
  return filename
153
 
154
  def plot_family_results(file_path, method_names, metric, save_path="./plot_images"):
 
119
 
120
  return filename
121
 
122
+ def plot_function_results(file_path, aspect, metric, method_names, function_path="/tmp/function_results.csv"):
123
+ if not os.path.exists(function_path):
124
+ benchmark_types = ["similarity", "function", "family", "affinity"] #download all files for faster results later
125
+ download_from_hub(benchmark_types)
126
+
127
  # Load data
128
+ df = pd.read_csv(function_path)
129
 
130
  # Filter for selected methods
131
  df = df[df['Method'].isin(method_names)]
 
135
  df = df[['Method'] + columns_to_plot]
136
  df.set_index('Method', inplace=True)
137
 
138
+ # Fill missing values with 0
139
+ df = df.fillna(0)
140
+
141
+ # Generate colors for methods
142
+ row_color_dict = {method: get_method_color(method) for method in df.index}
143
+
144
+ long_form_mapping = {
145
+ "MF": "Molecular Function",
146
+ "BP": "Biological Process",
147
+ "CC": "Cellular Component"
148
+ }
149
+
150
  # Create clustermap
151
  g = sns.clustermap(df, annot=True, cmap="YlGnBu", row_cluster=False, col_cluster=False, figsize=(15, 15))
152
+
153
+ title = f"{long_form_mapping[aspect.capitalize()]} Results for {metric.capitalize()}"
154
+ g.fig.suptitle(title, x=0.5, y=1.02, fontsize=16, ha='center') # Center the title above the plot
155
+
156
  # Get heatmap axis and customize labels
157
  ax = g.ax_heatmap
158
  ax.set_xlabel("")
159
  ax.set_ylabel("")
 
 
 
160
 
161
  # Save the plot as an image
162
+ save_path = "/tmp"
 
163
  filename = os.path.join(save_path, f"{aspect}_{metric}_heatmap.png")
164
  plt.savefig(filename, dpi=400, bbox_inches='tight')
165
  plt.close() # Close the plot to free memory
166
+
167
  return filename
168
 
169
  def plot_family_results(file_path, method_names, metric, save_path="./plot_images"):