-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathcache_utils.py
More file actions
65 lines (49 loc) · 2.08 KB
/
cache_utils.py
File metadata and controls
65 lines (49 loc) · 2.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import functools
import glob
import os
import pickle

from scipy.sparse import csr_matrix, save_npz, load_npz
def cache_single_output(output, cache_file):
    """Write a single value to ``cache_file`` and report where it went.

    A ``scipy.sparse.csr_matrix`` is stored with ``save_npz``; every other
    object is pickled.

    Note that ``save_npz`` appends ``.npz`` to a string path that does not
    already end with it, so a CSR matrix should be given a ``.npz`` path if
    it is to be found again under the same name.

    Args:
        output: the object to persist.
        cache_file: destination path.
    """
    if not isinstance(output, csr_matrix):
        # Generic objects: plain binary pickle.
        with open(cache_file, 'wb') as handle:
            pickle.dump(output, handle)
    else:
        save_npz(cache_file, output)
    print("Cached output to {}.".format(cache_file))
def load_single_cache_output(cache_file):
    """Load one value from a cache file and report where it came from.

    The reader is chosen from the file name: a path ending in ``.npz`` is
    read with ``scipy.sparse.load_npz``; any other path is unpickled.

    Args:
        cache_file: path (``str`` or path-like) of the cache file.

    Returns:
        The loaded object.
    """
    path = os.fspath(cache_file)
    # Test the suffix, not a substring: ``'.npz' in path`` also matched
    # directory names such as ``runs.npz_old/out.pkl`` and then tried to
    # read a pickle file as an .npz archive.
    if path.endswith('.npz'):
        output = load_npz(path)
    else:
        # SECURITY: pickle.load can execute arbitrary code; only load cache
        # files from a trusted location (e.g. ones this code wrote itself).
        with open(path, 'rb') as f:
            output = pickle.load(f)
    print("Loaded cache from {}.".format(cache_file))
    return output
def manage_cache(cache_dir, extensions):
    """Decorator factory that caches a function's output(s) on disk.

    Each output of the decorated function is stored in its own file in
    ``cache_dir`` named ``"<func name>_<run id>_<index><extension>"``. The
    run id joins ``str()`` of the positional arguments with ``"-"``, followed
    by ``key=value`` for each keyword argument in sorted key order. When all
    files for a run id already exist, they are loaded instead of calling the
    function.

    Args:
        cache_dir: directory that holds the cache files; created immediately
            (at decoration time) if it does not exist.
        extensions: either one extension string, meaning the function returns
            a single value, or a sequence of extensions, one per element of the
            sequence the function returns. The extension decides how a file is
            read back: a ``.npz`` file is read with ``load_npz`` (so that output
            should be a ``csr_matrix``); anything else is unpickled.

    Returns:
        A decorator. The wrapped function returns the single value, or a tuple
        of values, read back from the cache files.

    Raises:
        ValueError: if ``extensions`` is empty, or if the decorated function
            returns a different number of outputs than there are extensions.
    """
    single_output = isinstance(extensions, str)
    if single_output:
        extensions = [extensions]
    else:
        # Snapshot so later mutation by the caller cannot change the layout.
        extensions = list(extensions)
    if not extensions:
        raise ValueError("manage_cache needs at least one extension")

    os.makedirs(cache_dir, exist_ok=True)

    def decorator_manage_cache(func):
        @functools.wraps(func)
        def wrapper_cache_output(*args, **kwargs):
            # NOTE(review): str(arg) is used verbatim in the file name, so
            # arguments whose str() contains a path separator, or whose str()
            # is truncated (e.g. large arrays), give bad or colliding keys.
            parts = [str(arg) for arg in args]
            # Keyword arguments used to be ignored, so calls that differed
            # only in kwargs shared one cache entry. With no kwargs the run
            # id is exactly what it was before, so existing caches still hit.
            parts += ["{}={}".format(key, kwargs[key]) for key in sorted(kwargs)]
            run_id = "-".join(parts)

            cache_files = [
                os.path.join(
                    cache_dir,
                    "{}_{}_{}{}".format(func.__name__, run_id, i, extension),
                )
                for i, extension in enumerate(extensions)
            ]

            # Require every file, not just the first one: an interrupted
            # earlier run can leave a partial set behind.
            if not all(os.path.isfile(path) for path in cache_files):
                outputs = func(*args, **kwargs)
                outputs = [outputs] if single_output else list(outputs)
                if len(outputs) != len(cache_files):
                    raise ValueError(
                        "{} returned {} outputs but {} cache extensions "
                        "were given".format(
                            func.__name__, len(outputs), len(cache_files)
                        )
                    )
                for output, path in zip(outputs, cache_files):
                    cache_single_output(output, path)

            # Always read back from disk, even right after computing, so a
            # fresh call returns the same round-tripped values a cached call
            # would.
            outputs = [load_single_cache_output(path) for path in cache_files]
            if single_output:
                return outputs[0]
            return tuple(outputs)

        return wrapper_cache_output

    return decorator_manage_cache