forked from jfsantos/seq2seq
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathWeightNoise.lua
More file actions
38 lines (28 loc) · 739 Bytes
/
WeightNoise.lua
File metadata and controls
38 lines (28 loc) · 739 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
--- nn.WeightNoise: keeps its own copy of a parameter tensor (`weight`) and
-- can draw Gaussian-perturbed samples of it (see Sample) or return it
-- unperturbed (see Mode / getWeights).
-- The unused `eps` constant that previously lived here was removed.
local WeightNoise, parent = torch.class('nn.WeightNoise', 'nn.Module')
--- Construct the module from an existing parameter tensor.
-- @param parameters tensor whose values become the initial `weight`
--        (it is cloned, so the caller's tensor is not aliased)
-- @param sigma optional standard deviation of the noise added by Sample;
--        defaults to 1e-3
function WeightNoise:__init(parameters, sigma)
   parent.__init(self)
   self.sigma = sigma or 1e-3
   self.weight = parameters:clone()
   -- Gradient accumulator: must start at zero. It was previously a plain
   -- clone of `parameters`, which seeded the gradient with the weight values.
   self.gradWeight = parameters:clone():zero()
   -- Scratch buffer reused by Sample(); its contents are overwritten per call.
   self.sample = parameters:clone()
end
--- Accessor for the stored (noise-free) parameter tensor.
-- @return the module's `weight` tensor itself (not a copy)
function WeightNoise:getWeights()
   local w = self.weight
   return w
end
--- Draw a noisy copy of the weights: weight + sigma * N(0, 1).
-- The result is written into the reusable `self.sample` buffer, so each
-- call overwrites the tensor returned by the previous call.
-- @return the `self.sample` tensor holding the new draw
function WeightNoise:Sample()
   local buf = self.sample
   -- randn/mul/add all operate in place and return the tensor itself
   buf:randn(buf:size()):mul(self.sigma):add(self.weight)
   return buf
end
--- Noise-free point value of the parameters (no sampling is performed).
-- @return the module's `weight` tensor itself (not a copy)
function WeightNoise:Mode()
   local mode = self.weight
   return mode
end
--- Forward pass: passes the input through unchanged.
-- The input is stored as `self.nll` (kept for existing readers) and, per the
-- nn.Module contract, also as `self.output`.
-- @param nll value to pass through; presumably a negative log-likelihood
--        computed by the caller (NOTE(review): confirm against callers)
-- @return `nll`, unchanged
function WeightNoise:updateOutput(nll)
   self.nll = nll
   -- nn.Module:updateOutput is expected to store its result in self.output;
   -- previously self.output stayed the empty tensor from parent.__init.
   self.output = nll
   return self.output
end
--- Accumulate a gradient into `self.gradWeight`.
-- @param input unused
-- @param gradOutput tensor added into gradWeight; assumed to be the same size
--        as `weight` (TODO confirm against callers)
-- @param scale optional multiplier applied to gradOutput (nn.Module
--        convention); defaults to 1, matching the previous behavior
function WeightNoise:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   self.gradWeight:add(scale, gradOutput)
end