-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutil.lua
More file actions
43 lines (40 loc) · 1.02 KB
/
Copy pathutil.lua
File metadata and controls
43 lines (40 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
--
-- Created by IntelliJ IDEA.
-- User: sidharth
-- Date: 3/20/16
-- Time: 10:59 PM
-- Utilities: a debugger-entering error handler and Kaiming-style weight initialization for convolution, batch-normalization and linear layers.
--
local debugger = require 'fb.debugger'
--- Print an error value, then open an interactive fb.debugger session.
-- NOTE(review): the single-argument signature suggests this is used as an
-- `xpcall` message handler elsewhere; confirm against callers.
-- @param err the error value; handed unchanged to `print`
-- Defined as a global so other files can reference it by name.
standard_eh = function(err)
   local message = err
   print(message)
   debugger.enter()
end
--- Kaiming (He et al.) normal initialization for every convolution of type `name`.
-- Weights are drawn from N(0, sqrt(2/n)) with n = kW * kH * nOutputPlane (fan-out).
-- When cudnn >= 4000 is loaded, the bias tensors are removed entirely;
-- otherwise an existing bias is zeroed.
-- NOTE(review): dropping the bias presumably relies on each conv being followed
-- by batch normalization, which makes it redundant -- confirm.
-- @param model container exposing `findModules(name)` (Torch nn)
-- @tparam string name module type, e.g. 'cudnn.SpatialConvolution'
function ConvInit(model, name)
   -- `cudnn` is a global created by `require 'cudnn'` elsewhere and may be
   -- absent on CPU-only runs; fall back to zeroing the bias in that case
   -- instead of failing with "attempt to index a nil value".
   local drop_bias = type(cudnn) == 'table'
      and type(cudnn.version) == 'number'
      and cudnn.version >= 4000
   local modules = model:findModules(name)
   for _, conv in ipairs(modules) do
      local n = conv.kW * conv.kH * conv.nOutputPlane
      conv.weight:normal(0, math.sqrt(2 / n))
      if drop_bias then
         conv.bias = nil
         conv.gradBias = nil
      elseif conv.bias then
         -- a convolution built without a bias has nothing to zero
         conv.bias:zero()
      end
   end
end
--- Reset every batch-normalization module of type `name` to an identity affine map.
-- Sets the scale (weight) to 1 and the shift (bias) to 0.
-- @param model container exposing `findModules(name)` (Torch nn)
-- @tparam string name module type, e.g. 'nn.SpatialBatchNormalization'
function BNInit(model, name)
   local modules = model:findModules(name)
   for _, bn in ipairs(modules) do
      -- Modules without learnable affine parameters (e.g. constructed with
      -- affine=false in Torch nn) may have nil weight/bias; skip those rather
      -- than erroring on a nil index.
      if bn.weight then bn.weight:fill(1) end
      if bn.bias then bn.bias:zero() end
   end
end
--- Apply Kaiming-style initialization to all supported layers of `model`.
-- Convolutions are initialized by ConvInit and batch-norm layers by BNInit,
-- in the fixed order listed below. For nn.Linear only the bias is zeroed;
-- linear weights keep whatever initialization they already have.
-- @param model container exposing `findModules(name)` (Torch nn)
function kaimingInit(model)
   local conv_types = { 'cudnn.SpatialConvolution', 'nn.SpatialConvolution' }
   local bn_types = {
      'fbnn.SpatialBatchNormalization',
      'cudnn.SpatialBatchNormalization',
      'nn.SpatialBatchNormalization',
   }
   for _, conv_type in ipairs(conv_types) do
      ConvInit(model, conv_type)
   end
   for _, bn_type in ipairs(bn_types) do
      BNInit(model, bn_type)
   end
   local linears = model:findModules('nn.Linear')
   for _, linear in ipairs(linears) do
      -- a linear layer built without a bias has bias == nil; nothing to zero
      if linear.bias then linear.bias:zero() end
   end
end