-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexport.go
More file actions
87 lines (67 loc) · 2.04 KB
/
export.go
File metadata and controls
87 lines (67 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
package deeper
import (
"encoding/json"
"fmt"
"io"
"gonum.org/v1/gonum/mat"
)
// Exporter serializes a Network to an io.Writer and restores one from an
// io.Reader. Implementations define the on-disk format.
type Exporter interface {
	// Save writes src to dst. The caller is responsible for closing dst.
	Save(dst io.Writer, src *Network) error
	// Load populates dst from src. The caller is responsible for closing src.
	Load(dst *Network, src io.Reader) error
}
// Export is the default Exporter implementation; it serializes networks as JSON.
type Export struct{}
// jsonMatrix is the JSON form of a dense matrix: its dimensions plus the
// raw matrix data in row-major order.
type jsonMatrix struct {
	Rows int `json:"rows"`
	Cols int `json:"cols"`
	Data []float64 `json:"data"`
}
// jsonNetwork is the serialized form of a Network: the size of every layer,
// plus one weight matrix and one bias matrix per non-input layer (so
// Weights and Biases each hold len(Sizes)-1 entries).
type jsonNetwork struct {
	Sizes []int `json:"sizes"`
	Weights []jsonMatrix `json:"weights"`
	Biases []jsonMatrix `json:"biases"`
}
// NewExporter returns an Exporter that can persist and restore trained models.
func NewExporter() Exporter {
	return new(Export)
}
// Save exports an existing and likely trained network to a destination
// (i.e., a file). It is the responsibility of the caller to close the writer.
func (e *Export) Save(dst io.Writer, src *Network) error {
	j := jsonNetwork{}
	for _, l := range src.Layers {
		j.Sizes = append(j.Sizes, l.Rows())
		if l.IsInput() {
			// Input layers carry no weights or biases.
			continue
		}
		// Record the matrix dimensions alongside the raw data so the saved
		// file is self-describing (previously Rows/Cols were always zero).
		w := l.Weights().RawMatrix()
		b := l.Biases().RawMatrix()
		j.Weights = append(j.Weights, jsonMatrix{Rows: w.Rows, Cols: w.Cols, Data: w.Data})
		j.Biases = append(j.Biases, jsonMatrix{Rows: b.Rows, Cols: b.Cols, Data: b.Data})
	}
	if err := json.NewEncoder(dst).Encode(j); err != nil {
		return fmt.Errorf("could not marshal network: %w", err)
	}
	return nil
}
// Load loads a previously exported network from its saved state for inference.
// It is the responsibility of the caller to close the reader.
//
// Malformed input (wrong matrix counts or inconsistent data lengths) is
// reported as an error rather than causing a runtime panic.
func (e *Export) Load(dst *Network, src io.Reader) error {
	j := jsonNetwork{}
	if err := json.NewDecoder(src).Decode(&j); err != nil {
		return fmt.Errorf("couldn't decode saved network: %w", err)
	}
	// A well-formed export carries exactly one weight and one bias matrix per
	// non-input layer. Validate up front so corrupted input returns an error
	// instead of panicking below (index out of range / mat.NewDense panic).
	if len(j.Weights) != len(j.Sizes)-1 || len(j.Biases) != len(j.Sizes)-1 {
		return fmt.Errorf("malformed saved network: %d sizes but %d weight and %d bias matrices",
			len(j.Sizes), len(j.Weights), len(j.Biases))
	}
	dst.Layers = make([]BackpropagationLayer, 0, len(j.Sizes))
	for i, size := range j.Sizes {
		var l BackpropagationLayer
		switch {
		case i == 0:
			l = NewInputLayer(size)
		case i == len(j.Sizes)-1:
			l = NewOutputLayer(size, NewSoftmax())
		default:
			l = NewHiddenLayer(size, NewSigmoid())
		}
		if !l.IsInput() {
			w, b := j.Weights[i-1], j.Biases[i-1]
			// mat.NewDense panics when the data length does not match the
			// requested dimensions; check first and fail gracefully.
			if len(w.Data) != size*j.Sizes[i-1] || len(b.Data) != size {
				return fmt.Errorf("malformed saved network: layer %d has inconsistent matrix data", i)
			}
			l.SetWeights(mat.NewDense(size, j.Sizes[i-1], w.Data))
			l.SetBiases(mat.NewDense(size, 1, b.Data))
		}
		dst.AddLayerWoWeightInitialization(l)
	}
	return nil
}