-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_argmedian.py
More file actions
53 lines (43 loc) · 1.89 KB
/
Copy pathplot_argmedian.py
File metadata and controls
53 lines (43 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np
import dill
import argparse
from matplotlib import pyplot as plt
import numpy as np
def plot_metric():
    """Plot the score trajectories of the per-step argmedian experiments.

    Reads the module-level globals ``dump``, ``dataset``, ``metric``,
    ``n_step`` and ``output`` (assigned in the ``__main__`` section).
    ``dump[dataset][metric][step]`` is indexed for ``step`` in
    ``range(n_step)``; each entry is presumably a 1-D sequence of
    per-experiment scores indexed by experiment id (TODO confirm).

    For every step ``s`` in ``1 .. n_step - 1`` it picks the experiment whose
    score at step ``s`` is the (upper) median. It then draws that
    experiment's scores across all steps as one dashed line. The median
    score of every step, rounded to 3 decimals, is written next to its x
    position. The figure is saved to ``output`` and then closed.
    """
    scores = dump[dataset][metric]
    steps = list(range(0, n_step))

    medians = [np.median(scores[0]).round(3)]
    sample_ids = list()
    # Iterate over the configured number of steps. This was previously
    # hard-coded as range(1, 6), which only matched the default --n_step 5.
    for step in range(1, n_step):
        data = scores[step]
        medians.append(np.median(data).round(3))
        # Index of the element at sorted position len // 2, i.e. the upper
        # middle element when len(data) is even.
        argmedian = np.argsort(data)[len(data) // 2]
        sample_ids.append(argmedian)

    # One trajectory (score at every step) per selected experiment.
    lines = [[scores[step][iter_] for step in steps] for iter_ in sample_ids]

    color = ['green', 'blue', 'red', 'yellow', 'orange', 'purple']
    # Use a dedicated figure so repeated calls do not draw on top of each other.
    fig = plt.figure()
    for i, line in enumerate(lines):
        # Cycle through the palette so more than len(color) lines do not raise IndexError.
        plt.plot(steps, line, color=color[i % len(color)], linestyle='--', marker='o', label=f'{i+1}-step')
    plt.xlabel('Time step')
    ylabel = 'NMI' if metric.startswith('normalized') else 'Pur'
    plt.ylabel(ylabel)
    # Each line carries a label, but no legend is drawn.
    for x, y in zip(steps, medians):
        plt.text(x=x, y=y, s=f"{y}")
    plt.savefig(output)
    # Release the figure's resources once it has been written to disk.
    plt.close(fig)
if __name__ == '__main__':
    # Command-line entry point: load a dill-serialized results dump and plot
    # the argmedian trajectories for one dataset/metric pair.
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_iter", default=100, type=int, help="number of experiments")
    parser.add_argument("--n_step", default=5, type=int, help="number of time steps")
    parser.add_argument("--metric", default='normalized_mutual_info_score', type=str, help="metric to analyse")
    # These three were optional before, but a missing value crashed later
    # with an obscure error (open(None), dump[None], savefig(None)).
    parser.add_argument("--output", type=str, required=True, help="the path to the file where the plot will be saved")
    parser.add_argument("--input", type=str, required=True, help="the path to the dill file containing the results")
    parser.add_argument("--dataset", type=str, required=True, help="the name of the dataset (key in the loaded results) to plot")
    args = parser.parse_args()
    # The names below are module-level globals read by plot_metric().
    n_iter = args.n_iter  # parsed for CLI compatibility; not read by plot_metric()
    # +1 accounts for the initial step 0 stored alongside the n_step later steps.
    n_step = args.n_step + 1
    dataset = args.dataset
    metric = args.metric
    output = args.output
    # NOTE(security): dill/pickle loading can execute arbitrary code; only
    # load trusted result files. The `with` block closes the file handle.
    with open(args.input, mode='rb') as f:
        dump = dill.load(f)
    plot_metric()