-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathviz_arm_means_multiple_exp.py
More file actions
120 lines (88 loc) · 4.5 KB
/
viz_arm_means_multiple_exp.py
File metadata and controls
120 lines (88 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.ticker as ticker
import argparse
import os
import glob
# Plot per-arm mean utility at one iteration for every .npz result file matching
# a glob pattern, one colored curve per file, and save the figure as a PDF.
parser = argparse.ArgumentParser(description="Visualize arm means for multiple experiments.")
parser.add_argument("--prolific_results_folder", type=str, required=True, help="Path to the folder containing Prolific data.")
parser.add_argument("--pattern_npz", type=str, required=True, help="Pattern to match the .npz files (e.g., 'bandit-results-*.npz').")
parser.add_argument("--ci", action="store_true", help="Whether to plot confidence intervals or std.")
parser.add_argument("--iter", type=int, default=None, help="Iteration to visualize (default = n_steps).")
parser.add_argument("--file_name", type=str, required=True, help="File name.")
parser.add_argument("--title_fig", type=str, default=None, help="Title of the figure.")
parser.add_argument("--min-arm", type=int, required=True, help="Minimum arm value for the x-axis.")
parser.add_argument("--ylim", type=float, nargs=2, default=None, help="Y-axis limits for the plot (optional).")

## Reading arguments
args = parser.parse_args()
pattern = args.pattern_npz
ci = args.ci  # True -> 95% CI of the mean; False -> +/- one std
prolific_results_folder = args.prolific_results_folder
ITER = args.iter  # 1-based iteration; None -> last iteration
file_name = args.file_name
title_fig = args.title_fig
min_arm = args.min_arm
ylim = args.ylim

# Example patterns used with this script:
#   'bandit-results-h2000-s50-filtered_k*_m2_random-*.npz'
#   'bandit-results-h2000-s50-filtered_k*_m2_unassign-*.npz'
#   'bandit-results-h1000-s50-*_q46_nonfiltered_random-*.npz'
SAVE_PATH = os.path.join(prolific_results_folder, 'bandit')
files = sorted(glob.glob(os.path.join(SAVE_PATH, pattern)))
print(files)
if not files:
    # Without inputs the means below would be NaN and the colormap degenerate.
    parser.error(f"no files match {pattern!r} in {SAVE_PATH!r}")

##*#################
##* PLOTTING FIGURE
##*#################
# One distinct color per input file. plt.get_cmap replaces cm.get_cmap, which
# was deprecated in Matplotlib 3.7 and removed in 3.9.
cmap = plt.get_cmap('cividis', len(files))
# NOTE(review): the 'science' style must be installed (presumably via SciencePlots).
plt.style.use('science')
fig, ax = plt.subplots(figsize=(5.5, 3.5))

alg_means_b0 = []  # per-file mean of arm column 0 at the selected iteration
for idx, npz_path in enumerate(files):
    # The context manager closes the file handle held by the NpzFile.
    with np.load(npz_path, allow_pickle=True) as results:
        all_arm_means = results['arm_means']  # indexed below as (sim, step, arm)
        arms = np.asarray(results['arms'])
        n_sims = results['num_sims']

    # Resolve the 1-based iteration to a 0-based step index; default = last step.
    n_steps = all_arm_means.shape[1]
    step = n_steps if ITER is None else ITER
    if not 1 <= step <= n_steps:
        raise ValueError(f"--iter must be in [1, {n_steps}] for {npz_path}, got {step}")
    tt = step - 1

    # Column 0 is assumed to be the b=0 (algorithm-only) arm.
    alg_means_b0.append(all_arm_means[:, tt, 0].mean())

    # Select columns with the same mask used for the x values so x and y stay
    # aligned (identical to slicing [min_arm:] when arms are 0..K-1).
    keep = arms >= min_arm
    arms = arms[keep]
    at_step = all_arm_means[:, tt, :][:, keep]  # (sim, selected arm)
    mean_t = at_step.mean(axis=0)
    std_t = at_step.std(axis=0)

    # Shaded band: normal-approximation 95% CI of the mean, or +/- one std.
    band = 1.96 * std_t / np.sqrt(n_sims) if ci else std_t

    ax.plot(arms, mean_t, marker='o', color=cmap(idx), label=f"idx = {idx}", alpha=.5)
    ax.fill_between(arms, mean_t - band, mean_t + band, color=cmap(idx), alpha=0.1)

# Reference line: b=0 utility averaged over all files.
b0_mean = np.array(alg_means_b0).mean()
ax.axhline(b0_mean, color=plt.get_cmap('Set2')(1), linestyle='--', zorder=0, linewidth=1.5)
# NOTE(review): x=12 is a hard-coded data coordinate; the label lands outside
# the axes when the plotted arms do not extend to 12.
ax.text(12, b0_mean + 0.0005, "$b=0$", color='black', fontsize=12, ha='left', va='bottom')

ax.set_xlabel("$b$", fontsize=14)
ax.set_ylabel("Expected Utility", fontsize=14)
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.xaxis.set_minor_locator(ticker.NullLocator())
ax.tick_params(axis='x', which='major', bottom=True, length=5, labelsize=14)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.tick_params(top=False, right=False, which='both')
ax.tick_params(axis='y', which='major', left=True, length=5, labelsize=14)
ax.tick_params(axis='y', which='minor', left=True, length=2, labelsize=0)  # labelsize=0 hides minor labels
if ylim is not None:
    ax.set_ylim(ylim[0], ylim[1])

# Colorbar keyed by file index (the same index used for each curve's color).
sm = cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=len(files) - 1))
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, ticks=np.arange(len(files)), fraction=0.046, pad=0.04)
cbar.set_label('$u$', fontsize=12)
cbar.ax.set_yticklabels([f"{i}" for i in range(len(files))])
cbar.ax.tick_params(labelsize=12)

if title_fig is not None:
    plt.title(title_fig, fontsize=14)
plt.tight_layout()

## Saving the figure
os.makedirs(SAVE_PATH, exist_ok=True)
plt.savefig(os.path.join(SAVE_PATH, file_name + '.pdf'), bbox_inches='tight', pad_inches=0.1, dpi=300)