-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpruning_tools.py
More file actions
86 lines (75 loc) · 3.46 KB
/
Copy pathpruning_tools.py
File metadata and controls
86 lines (75 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch.nn.utils.prune as prune
import torch
import numpy as np
class Regrowth(prune.BasePruningMethod):
    '''Re-enable a fraction of the entries that are currently pruned.

    The regrowth method given to the constructor must be either
     - "random"    to re-enable pruned entries picked uniformly at random, or
     - "magnitude" to re-enable the pruned entries whose score in ``t`` has
                   the largest absolute value.
    '''
    # NOTE(review): 'global' is presumably chosen so that, when stacked on an
    # earlier pruning step, torch passes the whole tensor and its full current
    # mask to compute_mask (see BasePruningMethod docs) -- confirm.
    PRUNING_TYPE = 'global'

    def __init__(self,
                 regrowth_method,
                 amount,
                 seed=5000):
        '''Store the regrowth configuration.

        Args:
            regrowth_method: "random" or "magnitude".
            amount: fraction of the currently pruned entries to regrow,
                expected in (0, 1).
            seed: stored on the instance; it is not consumed by
                compute_mask, which draws from numpy's global RNG.
        '''
        self.regrowth_method = regrowth_method
        self.amount = amount
        self.seed = seed

    def compute_mask(self, t, default_mask):
        '''Return a copy of ``default_mask`` with some zero entries set to 1.

        Args:
            t: score tensor with the same shape as ``default_mask``; only
                read by the "magnitude" method.
            default_mask: current mask, where 0 marks a pruned entry.

        Returns:
            A new tensor shaped like ``default_mask``. Entries that were
            already kept stay kept; ``int(amount * num_pruned)`` pruned
            entries (capped at ``num_pruned``) are switched to 1.

        Raises:
            ValueError: if ``regrowth_method`` is neither "random" nor
                "magnitude".
        '''
        if self.regrowth_method not in ('random', 'magnitude'):
            raise ValueError(
                "regrowth_method must be 'random' or 'magnitude', got "
                f"{self.regrowth_method!r}")

        # Work on a flat private copy so that a single 1-D index addresses
        # exactly one element, whatever the rank of the mask is.
        mask_flat = default_mask.clone().reshape(-1)
        pruned_flat = mask_flat == 0  # complement of the mask
        num_pruned = int(pruned_flat.sum())
        num_to_regrow = min(int(self.amount * num_pruned), num_pruned)
        if num_to_regrow == 0:
            return mask_flat.reshape(default_mask.shape)

        if self.regrowth_method == 'random':
            pruned_positions = torch.nonzero(pruned_flat).squeeze(1)
            picks = np.random.choice(num_pruned, size=num_to_regrow, replace=False)
            picks = torch.as_tensor(picks, dtype=torch.long,
                                    device=pruned_positions.device)
            chosen = pruned_positions[picks]
        else:
            # Kept entries get a score of -1, below any absolute value, so
            # they can never win a tie against a pruned entry whose score
            # happens to be zero.
            scores = torch.abs(t).reshape(-1).masked_fill(~pruned_flat, -1)
            chosen = torch.topk(scores, num_to_regrow).indices

        mask_flat[chosen] = 1.
        return mask_flat.reshape(default_mask.shape)
def regrowth_unstructured(module, name, regrowth_method, amount, seed=5000):
    '''Run Regrowth on the parameter called ``name`` of ``module``.

    The three configuration values are forwarded as keyword arguments to
    ``Regrowth.apply``, the classmethod inherited from
    ``prune.BasePruningMethod``. Nothing is returned.

    Args:
        module: module holding the parameter.
        name: name of the parameter inside ``module``.
        regrowth_method: "random" or "magnitude".
        amount: fraction forwarded to ``Regrowth``.
        seed: seed value forwarded to ``Regrowth``.
    '''
    Regrowth.apply(
        module,
        name,
        regrowth_method=regrowth_method,
        amount=amount,
        seed=seed,
    )
class RegrowthRigL(prune.BasePruningMethod):
    '''Re-enable the pruned entries with the largest absolute score.

    Unlike a fraction of the pruned entries, the number of entries to regrow
    is ``amount`` times the total number of elements in the tensor.
    NOTE(review): the local names in the original suggest ``t`` carries
    gradients (RigL-style regrowth) -- confirm against the caller.
    '''
    # NOTE(review): 'global' is presumably chosen so that, when stacked on an
    # earlier pruning step, torch passes the whole tensor and its full current
    # mask to compute_mask (see BasePruningMethod docs) -- confirm.
    PRUNING_TYPE = 'global'

    def __init__(self,
                 amount):
        '''Store ``amount``, the fraction of all elements to regrow.'''
        self.amount = amount

    def compute_mask(self, t, default_mask):
        '''Return a copy of ``default_mask`` with top-scoring zeros set to 1.

        Args:
            t: score tensor with the same shape as ``default_mask``.
            default_mask: current mask, where 0 marks a pruned entry.

        Returns:
            A new tensor shaped like ``default_mask`` in which
            ``int(amount * t.numel())`` pruned entries (capped at the number
            of pruned entries) with the largest ``|t|`` are switched to 1.
        '''
        # Flat private copy: one 1-D index addresses exactly one element.
        mask_flat = default_mask.clone().reshape(-1)
        pruned_flat = mask_flat == 0  # complement of the mask
        num_pruned = int(pruned_flat.sum())
        # Only pruned entries can be regrown, so never ask for more of them
        # than exist.
        num_to_regrow = min(int(self.amount * torch.numel(t)), num_pruned)
        if num_to_regrow == 0:
            return mask_flat.reshape(default_mask.shape)

        # Kept entries get a score of -1, below any absolute value, so a
        # pruned entry with a zero score still outranks every kept entry.
        scores = torch.abs(t).reshape(-1).masked_fill(~pruned_flat, -1)
        chosen = torch.topk(scores, num_to_regrow).indices
        mask_flat[chosen] = 1.
        return mask_flat.reshape(default_mask.shape)
def regrowth_rigl(module, name, amount):
    '''Run RegrowthRigL on the parameter called ``name`` of ``module``.

    ``amount`` is forwarded as a keyword argument to ``RegrowthRigL.apply``,
    the classmethod inherited from ``prune.BasePruningMethod``. Nothing is
    returned.
    '''
    RegrowthRigL.apply(module, name, amount=amount)
if __name__ == "__main__":
    # Smoke run: exercise magnitude-based regrowth once on a small 2x3
    # example. The resulting mask is discarded.
    example_scores = torch.tensor([[20., 10., 1.],
                                   [1., 40., 10.]])
    example_mask = torch.tensor([[0., 0., 1.],
                                 [1., 0., 0.]])
    Regrowth(regrowth_method='magnitude', amount=0.5).compute_mask(
        example_scores, example_mask)