mgyigit commited on
Commit
b308ad4
·
verified ·
1 Parent(s): 3d6a4fb

Update src/vis_utils.py

Browse files
Files changed (1) hide show
  1. src/vis_utils.py +107 -0
src/vis_utils.py CHANGED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import re
4
+ import os
5
+ import json
6
+ import yaml
7
+ import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
+ import plotnine as p9
10
+
11
+ from about import *
12
+ global data_component, filter_component
13
+
14
+ def get_method_color(method):
15
+ return color_dict.get(method, 'black') # If method is not in color_dict, use black
16
+
17
+ def draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title):
18
+ df = pd.read_csv(CSV_RESULT_PATH)
19
+ # Filter the dataframe based on selected methods
20
+ filtered_df = df[df['method_name'].isin(methods_selected)]
21
+
22
+ def get_method_color(method):
23
+ return color_dict.get(method.upper(), 'black')
24
+
25
+ # Add a new column to the dataframe for the color
26
+ filtered_df['color'] = filtered_df['method_name'].apply(get_method_color)
27
+
28
+ adjust_text_dict = {
29
+ 'expand_text': (1.15, 1.4), 'expand_points': (1.15, 1.25), 'expand_objects': (1.05, 1.5),
30
+ 'expand_align': (1.05, 1.2), 'autoalign': 'xy', 'va': 'center', 'ha': 'center',
31
+ 'force_text': (.0, 1.), 'force_objects': (.0, 1.),
32
+ 'lim': 500000, 'precision': 1., 'avoid_points': True, 'avoid_text': True
33
+ }
34
+
35
+ # Create the scatter plot using plotnine (ggplot)
36
+ g = (p9.ggplot(data=filtered_df,
37
+ mapping=p9.aes(x=x_metric, # Use the selected x_metric
38
+ y=y_metric, # Use the selected y_metric
39
+ color='color', # Use the dynamically generated color
40
+ label='method_names')) # Label each point by the method name
41
+ + p9.geom_point(size=3) # Add points with no jitter, set point size
42
+ + p9.geom_text(nudge_y=0.02, size=8) # Add method names as labels, nudge slightly above the points
43
+ + p9.labs(title=title, x=f"{x_metric}", y=f"{y_metric}") # Dynamic labels for X and Y axes
44
+ + p9.scale_color_identity() # Use colors directly from the dataframe
45
+ + p9.theme(legend_position='none',
46
+ figure_size=(8, 8), # Set figure size
47
+ axis_text=p9.element_text(size=10),
48
+ axis_title_x=p9.element_text(size=12),
49
+ axis_title_y=p9.element_text(size=12))
50
+ )
51
+
52
+ # Save the plot as an image
53
+ save_path = "./plot_images" # Ensure this folder exists or adjust the path
54
+ os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
55
+ filename = os.path.join(save_path, title.replace(" ", "_") + "_Similarity_Scatter.png")
56
+
57
+ g.save(filename=filename, dpi=400)
58
+
59
+ return filename
60
+
61
+ def benchmark_plot(benchmark_type, methods_selected, x_metric, y_metric):
62
+ if benchmark_type == 'flexible':
63
+ # Use general visualizer logic
64
+ return general_visualizer_plot(methods_selected, x_metric=x_metric, y_metric=y_metric)
65
+ elif benchmark_type == 'similarity':
66
+ title = f"{x_metric} vs {y_metric}"
67
+ return draw_scatter_plot_similarity(methods_selected, x_metric, y_metric, title)
68
+ elif benchmark_type == 'Benchmark 3':
69
+ return benchmark_3_plot(x_metric, y_metric)
70
+ elif benchmark_type == 'Benchmark 4':
71
+ return benchmark_4_plot(x_metric, y_metric)
72
+ else:
73
+ return "Invalid benchmark type selected."
74
+
75
+
76
+ def get_baseline_df(selected_methods, selected_metrics):
77
+ df = pd.read_csv(CSV_RESULT_PATH)
78
+ present_columns = ["method_name"] + selected_metrics
79
+ df = df[df['method_name'].isin(selected_methods)][present_columns]
80
+ return df
81
+
82
+ def general_visualizer(methods_selected, x_metric, y_metric):
83
+ df = pd.read_csv(CSV_RESULT_PATH)
84
+ filtered_df = df[df['method_name'].isin(methods_selected)]
85
+
86
+ # Create a Seaborn lineplot with method as hue
87
+ plt.figure(figsize=(10, 8)) # Increase figure size
88
+ sns.lineplot(
89
+ data=filtered_df,
90
+ x=x_metric,
91
+ y=y_metric,
92
+ hue="method_name", # Different colors for different methods
93
+ marker="o", # Add markers to the line plot
94
+ )
95
+
96
+ # Add labels and title
97
+ plt.xlabel(x_metric)
98
+ plt.ylabel(y_metric)
99
+ plt.title(f'{y_metric} vs {x_metric} for selected methods')
100
+ plt.grid(True)
101
+
102
+ # Save the plot to display it in Gradio
103
+ plot_path = "plot.png"
104
+ plt.savefig(plot_path)
105
+ plt.close()
106
+
107
+ return plot_path