-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSNN.java
More file actions
70 lines (50 loc) · 2.35 KB
/
SNN.java
File metadata and controls
70 lines (50 loc) · 2.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
public class SNN {
public double lr;
public int hidden, output;
public Matrix weightsIH, weightsHO, result, biasH, biasO;
SNN(int input, int hidden, int output){
this.hidden = hidden;
this.output = output;
this.weightsIH = new Matrix(this.hidden, input);
this.weightsIH.randomise();
this.weightsHO = new Matrix(this.output, this.hidden);
this.weightsHO.randomise();
this.biasH = new Matrix(this.hidden, 1);
this.biasH.randomise();
this.biasO = new Matrix(this.output, 1);
this.biasO.randomise();
this.lr = Math.pow(10, -1);
}
public void train(Matrix input, Matrix target, boolean training){
Matrix hiddenInput = Matrix.dot(this.weightsIH, input);
hiddenInput = Matrix.add(hiddenInput, this.biasH);
hiddenInput = SNN.activate(hiddenInput, false);
Matrix out = Matrix.dot(this.weightsHO, hiddenInput);
out = Matrix.add(out, this.biasO);
out = SNN.activate(out, false);
this.result = out;
if(training){
Matrix outputError = Matrix.sub(target, out);
Matrix hiddenError = Matrix.dot(Matrix.transpose(this.weightsHO), outputError);
Matrix outputGradient = SNN.activate(out, true);
outputGradient = Matrix.schur(outputGradient, outputError);
outputGradient = Matrix.scMul(outputGradient, this.lr);
Matrix HOdeltas = Matrix.dot(outputGradient, Matrix.transpose(hiddenInput));
this.weightsHO = Matrix.add(this.weightsHO, HOdeltas);
this.biasO = Matrix.add(this.biasO, outputGradient);
Matrix hiddenGradient = SNN.activate(hiddenInput, true);
hiddenGradient = Matrix.schur(hiddenGradient, hiddenError);
hiddenGradient = Matrix.scMul(hiddenGradient, this.lr);
Matrix IHdeltas = Matrix.dot(hiddenGradient, Matrix.transpose(input));
this.weightsIH = Matrix.add(this.weightsIH, IHdeltas);
this.biasH = Matrix.add(this.biasH, hiddenGradient);
}
}
public static Matrix activate(Matrix m, boolean deriv){
return Matrix.RelU(m, deriv);
}
public Matrix predict(Matrix input){
this.train(input, new Matrix(0,0), false);
return this.result;
}
}