-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtestRun.lua
More file actions
91 lines (63 loc) · 2.81 KB
/
Copy pathtestRun.lua
File metadata and controls
91 lines (63 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
--- Run the model over the test data for `opt.totalIter` batches and log the
-- per-iteration error values.
--
-- Each iteration:
--   1. fills the preallocated CUDA batch tensors via `loadup`,
--   2. builds the 6-channel network input (RGB frame + upsampled pose image),
--      warps it with `fulltransform`, scales it by 1/255 and subtracts
--      `pixMean`,
--   3. calls `evaltest`,
--   4. prints and appends the Pred/Dec/KLD error values to
--      `opt.outDir .. "loss.txt"`.
--
-- Relies on globals defined elsewhere: `opt` (batchsize, numSteps,
-- numDecSteps, totalIter, outDir), `torch`, `loadup`, `UpSample`,
-- `UpSample2`, `fulltransform`, `pixMean` and `evaltest`.
--
-- Global side effects (kept global on purpose): `theInputs` (the preprocessed
-- batch), `curExamp` (current iteration index), `err`, `errkld`, `errdec`
-- and `nClock`. `evaltest` is not passed `theInputs` or the iteration index,
-- so it presumably reads them as globals.
-- NOTE(review): confirm against evaltest.
function testRun()
  -- Batch buffers, allocated once and moved to the GPU. loadup is expected to
  -- fill them in place each iteration (it receives them by reference).
  local theImages    = torch.Tensor(opt.batchsize, 3, 240, 320):cuda()
  local theRealPos   = torch.Tensor(opt.batchsize, 3, 120, 160):cuda()
  local theMag       = torch.Tensor(opt.batchsize, 2 * opt.numSteps):cuda()
  local thePos       = torch.Tensor(opt.batchsize, opt.numSteps, 3, 18):cuda()
  local theChange    = torch.Tensor(opt.batchsize, 36, opt.numSteps):cuda()
  local theMagDec    = torch.Tensor(opt.batchsize, 2 * opt.numDecSteps):cuda()
  local thePosDec    = torch.Tensor(opt.batchsize, opt.numDecSteps, 3, 18):cuda()
  local theChangeDec = torch.Tensor(opt.batchsize, 36, opt.numDecSteps):cuda()
  local theHom       = torch.Tensor(opt.batchsize, 2, 80, 60):cuda()

  for i = 1, opt.totalIter do
    -- os.clock() is process CPU time, not wall-clock time; asynchronous GPU
    -- work may not be fully reflected in the "load" figure printed below.
    nClock = os.clock()

    -- Restore the expected shapes before handing the buffers to loadup.
    theHom = theHom:reshape(opt.batchsize, 2, 80, 60)
    theImages = theImages:reshape(opt.batchsize, 3, 240, 320)
    theRealPos = theRealPos:reshape(opt.batchsize, 3, 120, 160)
    loadup(theImages, theHom, thePos, theChange, theRealPos,
           thePosDec, theChangeDec, theMag, theMagDec)

    -- clone() detaches the result from UpSample2's reused output buffer.
    local tempPos = UpSample2:forward(theRealPos):clone()

    theHom = theHom:reshape(opt.batchsize, 2, 80, 60):cuda()
    local theHomBig = UpSample:forward(theHom)
    theHomBig = theHomBig:reshape(opt.batchsize, 2, 320, 240)
    theImages = theImages:reshape(opt.batchsize, 3, 240, 320)

    -- Concatenate frame and pose image along dim 2 (6 channels in total, as
    -- the reshape below requires).
    theInputs = torch.cat(theImages, tempPos, 2)
    -- Swap the two spatial dims so the layout is 320x240, matching theHomBig,
    -- while fulltransform is applied; swap back afterwards.
    theInputs = theInputs:transpose(4, 3)
    theInputs = fulltransform(theInputs, theHomBig)
    theInputs = theInputs:transpose(4, 3)
    theInputs = theInputs:reshape(opt.batchsize, 6, 240, 320)

    -- Channels first, so theInputs[c] addresses channel c of every sample.
    theInputs = theInputs:transpose(2, 1)
    -- Presumably 0-255 pixel values; scale before mean subtraction.
    theInputs = theInputs / 255.0
    -- Channels 1-3 and 4-6 both subtract the same three per-channel means.
    for c = 1, 6 do
      theInputs[c] = theInputs[c] - pixMean[(c - 1) % 3 + 1]
    end
    theInputs = theInputs:transpose(2, 1)

    thePos = thePos:reshape(opt.batchsize, opt.numSteps, 3 * 18)
    curExamp = i

    -- Reset the error globals *before* evaluation so that whatever evaltest
    -- records for this batch is what gets logged. Resetting them after
    -- evaltest would make every logged value 0.
    err = 0
    errkld = 0
    errdec = 0
    evaltest(thePos, thePosDec, theChangeDec, theMagDec, theChange, theMag)

    -- Append mode; opened per iteration so each batch's lines are on disk
    -- even if a later iteration fails.
    local lossFile = assert(io.open(opt.outDir .. "loss.txt", "a"))
    local lines = {
      string.format("Iteration %d ; Pred err = %f\n", i, err),
      string.format("Iteration %d ; Dec err = %f\n", i, errdec),
      string.format("Iteration %d ; KLD err = %f\n", i, errkld),
    }
    for _, line in ipairs(lines) do
      print(line)
      lossFile:write(line)
    end
    lossFile:close()

    -- Despite the "load" label, this measures the whole iteration.
    print(string.format("load %.2f \n", os.clock() - nClock))
  end
end