-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathinit.py
More file actions
65 lines (56 loc) · 2.46 KB
/
init.py
File metadata and controls
65 lines (56 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import GetKNearestNeighbor
from evaluations import RMSE_label
from readerdata import *
from Datapreprocessing import *
from broadlearningsystem import *
from sklearn import metrics
def main():
    """Train a Broad Learning System model on the rating data and report its error.

    Steps:
      1. Read ``./data/transratings.csv`` via ``datareader`` into the user/item
         counts and a rating matrix.
      2. Find the ``k_user`` nearest neighbours of every user (rows) and the
         ``k_item`` nearest neighbours of every item (columns, via ``.T``).
      3. Build train/test inputs and labels with ``constuctinput``.
      4. Fit a ``broadNet`` model, printing the training time.
      5. Predict on the test split, print the time taken by ``weightPredict``,
         then print the MAE and RMSE of the ``predict`` output (raw numbers,
         one per line, MAE first).
    """
    # Local import: timing must not depend on names leaked by the star imports.
    import time

    data_file = './data/transratings.csv'
    map_num = 25
    enhance_num = 25
    map_batchsize = 15
    enh_batchsize = 10
    k_user = 5
    k_item = 5
    epoch = 2

    n_users, n_items, mtx_np = datareader(data_file)
    neighbor_user = GetKNearestNeighbor.k_neighbors(mtx_np, k_user, n_users)
    # Transposing swaps roles so neighbours are computed between items.
    neighbor_item = GetKNearestNeighbor.k_neighbors(mtx_np.T, k_item, n_items)
    print("data reader complete")

    traindata, testdata, trainlabel, testlabel = constuctinput(
        n_users, n_items, mtx_np, neighbor_user, k_user, neighbor_item, k_item)
    trainlabel = trainlabel.flatten()
    testlabel = testlabel.flatten()

    bls = broadNet(map_num=map_num,              # initial number of mapping-node groups
                   enhance_num=enhance_num,      # initial number of enhancement nodes
                   EPOCH=epoch,                  # number of training rounds
                   map_function='relu',
                   enhance_function='relu',
                   map_batchsize=map_batchsize,  # number of neurons in each group
                   enh_batchsize=enh_batchsize,
                   DESIRED_ACC=0.95,             # target accuracy to reach
                   STEP=1,                       # enhancement-node groups added per step
                   )

    # Sorted distinct training labels; the position in this list is used to
    # decode predictions back into label values below.
    labels = sorted(set(trainlabel))

    train_start = time.perf_counter()
    bls.fit(traindata, trainlabel)
    train_seconds = time.perf_counter() - train_start
    print('the training time of BLS is {0} seconds'.format(train_seconds))

    pre = bls.predict(testdata)

    # NOTE(review): only weightPredict is timed, and its output is unused;
    # the error metrics below are computed from ``predict``. Confirm which
    # call the reported test time is meant to measure.
    test_start = time.perf_counter()
    bls.weightPredict(testdata)
    test_seconds = time.perf_counter() - test_start
    print('the testing time of BLS is {0} seconds'.format(test_seconds))

    # NOTE(review): assumes ``pre`` holds integer indices into ``labels``
    # (as the original indexing ``labels[pre[i]]`` implies); verify.
    predicted_labels = [labels[idx] for idx in pre]
    mae = metrics.mean_absolute_error(testlabel, predicted_labels)
    rmse = RMSE_label(pre, testlabel, labels)
    print(mae)
    print(rmse)


if __name__ == "__main__":
    main()