-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·42 lines (34 loc) · 967 Bytes
/
Copy pathtrain.py
File metadata and controls
executable file
·42 lines (34 loc) · 967 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#!/usr/bin/env python3
#
# Created on Wed Mar 17 2021
#
# Arthur Lang
# train.py
#
import numpy
from tensorflow.keras.datasets import mnist
from src.Model import Model
def to_float(data):
    """Convert every element of *data* to a float and pack them into an array.

    Args:
        data: any iterable whose items ``float()`` accepts (e.g. numbers or
            numeric strings).

    Returns:
        numpy.ndarray: 1-D array of the converted values, in input order.
        An empty input yields an empty array.
    """
    # One comprehension replaces the append-in-a-loop; the result is identical.
    return numpy.array([float(item) for item in data])
def dataPreprocessing():
    """Load the MNIST dataset, normalize it, and add a trailing channel axis.

    Both splits are divided by the maximum pixel value of the training split.
    The images are then reshaped to ``(N, 28, 28, 1)``, with ``N`` inferred
    from the data rather than hard-coded.

    Returns:
        tuple: ``(dataTrain, labelTrain, dataTest, labelTest)``. The image
        arrays are normalized and reshaped; the labels are returned exactly
        as ``mnist.load_data()`` provides them.
    """
    # Mnist dataset
    (dataTrain, labelTrain), (dataTest, labelTest) = mnist.load_data()
    # Normalize both splits with the *training* maximum so they share one
    # scale and no statistic is taken from the test split.
    maxVal = numpy.max(dataTrain)
    # -1 lets numpy infer the sample count instead of assuming 60000 / 10000.
    # The trailing 1 is a single channel axis.
    dataTrain = (dataTrain / maxVal).reshape(-1, 28, 28, 1)
    dataTest = (dataTest / maxVal).reshape(-1, 28, 28, 1)
    return dataTrain, labelTrain, dataTest, labelTest
def train(outputWeightPath="ressources/weights/Mnist/"):
    """Run one training session on the preprocessed MNIST data.

    Builds a fresh ``Model``, sets its weight output path, and calls its
    ``train`` method with the train and test splits from
    ``dataPreprocessing()``. Any saving of weights happens inside
    ``Model.train``; the details live in ``src.Model``, not here.

    Args:
        outputWeightPath: directory passed to ``Model.setOutputWeightPath``.
            Defaults to the path previously hard-coded here, so existing
            ``train()`` calls behave exactly as before.

    Returns:
        int: always 0 (success status).
    """
    dataTrain, labelTrain, dataTest, labelTest = dataPreprocessing()
    nt = Model()
    nt.setOutputWeightPath(outputWeightPath)
    # NOTE(review): the test split is passed to Model.train; presumably it is
    # used for validation/evaluation. Confirm in src.Model.
    nt.train(dataTrain, labelTrain, dataTest, labelTest)
    return 0
if __name__ == "__main__":
    # Use train()'s return value as the process exit status instead of
    # discarding it (the standard `sys.exit(main())` pattern, without `sys`).
    raise SystemExit(train())