pual commited on
Commit
ab091f1
·
1 Parent(s): 02eb730

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +1408 -0
README.md CHANGED
@@ -1,3 +1,1411 @@
1
  ---
2
  license: mit
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ language:
4
+ - zh
5
+ pipeline_tag: image-classification
6
  ---
7
+ ```python
8
+ import numpy as np
9
+ import scipy.special as ssp
10
+ import matplotlib.pyplot as plt
11
+ ```
12
+
13
+
14
+ ```python
15
+ input_nodes=784 # 输入层节点数
16
+ hide_nodes=200 # 隐藏层节点数,理论上越高越好,但是高到一定程度就到顶了(默认:200)
17
+ out_nodes=10 # 输出层节点数
18
+ learningrate = 0.1 #学习率
19
+ ```
20
+
21
+
22
+ ```python
23
+ wih = np.random.normal(0.0, pow(hide_nodes, -0.5), (hide_nodes, input_nodes)) #矩阵大小为隐藏层节点数×输入层节点数
24
+ #np.random.normal()的意思是一个正态分布,normal这里是正态的意思
25
+ plt.hist(wih)
26
+ ```
27
+
28
+
29
+
30
+
31
+ (array([[ 1., 2., 3., ..., 8., 0., 0.],
32
+ [ 0., 0., 2., ..., 4., 0., 0.],
33
+ [ 0., 1., 5., ..., 9., 0., 0.],
34
+ ...,
35
+ [ 0., 1., 2., ..., 10., 0., 0.],
36
+ [ 0., 1., 13., ..., 7., 0., 0.],
37
+ [ 0., 2., 8., ..., 3., 1., 0.]]),
38
+ array([-0.32167192, -0.25702275, -0.19237358, -0.12772441, -0.06307524,
39
+ 0.00157393, 0.0662231 , 0.13087226, 0.19552143, 0.2601706 ,
40
+ 0.32481977]),
41
+ <a list of 784 BarContainer objects>)
42
+
43
+
44
+
45
+
46
+
47
+ ![png](output_2_1.png)
48
+
49
+
50
+
51
+
52
+ ```python
53
+ # Visualize weight matrix wih
54
+ plt.imshow(wih, cmap='coolwarm', aspect='auto')
55
+ #plt.imshow(wih, cmap='hot', aspect='auto')
56
+ plt.xlabel('Output Node')
57
+ plt.ylabel('Hidden Node')
58
+ plt.title('Weight Matrix (Hidden to input)')
59
+ plt.colorbar()
60
+ plt.show()
61
+ ```
62
+
63
+
64
+
65
+ ![png](output_3_0.png)
66
+
67
+
68
+
69
+
70
+ ```python
71
+ who = np.random.normal(0.0, pow(hide_nodes, -0.5), (out_nodes, hide_nodes)) #矩阵大小为输出层节点数×隐藏层节点数
72
+ plt.hist(who)
73
+ #同上
74
+ ```
75
+
76
+
77
+
78
+
79
+ (array([[0., 0., 1., ..., 0., 0., 0.],
80
+ [0., 0., 0., ..., 0., 0., 0.],
81
+ [0., 0., 0., ..., 1., 1., 0.],
82
+ ...,
83
+ [0., 0., 0., ..., 1., 1., 0.],
84
+ [0., 0., 0., ..., 1., 0., 0.],
85
+ [0., 1., 2., ..., 0., 0., 0.]]),
86
+ array([-0.26261651, -0.21194208, -0.16126765, -0.11059322, -0.05991879,
87
+ -0.00924436, 0.04143007, 0.0921045 , 0.14277893, 0.19345336,
88
+ 0.24412779]),
89
+ <a list of 200 BarContainer objects>)
90
+
91
+
92
+
93
+
94
+
95
+ ![png](output_4_1.png)
96
+
97
+
98
+
99
+
100
+ ```python
101
+ # Visualize weight matrix who
102
+ plt.imshow(who, cmap='coolwarm', aspect='auto')
103
+ plt.xlabel('Output Node')
104
+ plt.ylabel('Hidden Node')
105
+ plt.title('Weight Matrix (Hidden to Output)')
106
+ plt.colorbar()
107
+ plt.show()
108
+ ```
109
+
110
+
111
+
112
+ ![png](output_5_0.png)
113
+
114
+
115
+
116
+
117
+ ```python
118
+ #linspace 参考:https://blog.csdn.net/neweastsun/article/details/99676029
119
+ x = np.linspace(start=-6, stop=6, num=121) #从-6到6范围内创建121个距离相近的数字,从而生成x数组用于代入后面的y
120
+ '''
121
+ e.g.
122
+ x = np.linspace(start = 0, stop = 100, num = 5) ##从0到100范围内创建5个距离相近的数字
123
+ print(x)
124
+ OUT:[ 0. 25. 50. 75. 100.]
125
+
126
+ #lambda示例
127
+ #lambda arg1,arg2,arg3… :<表达式>
128
+ func=lambda x : x+1 #func=x+1
129
+ print(func(2)) #func=2+1=3
130
+ func=lambda x,y : x+y #func=x+y
131
+ print(func(1,2)) #func=1+2=3
132
+ '''
133
+ activation_function = lambda x: ssp.expit(x) #logistic sigmoid函数,定义为expit(x)= 1 /(1 + exp(-x))
134
+ y = activation_function(x)
135
+ plt.plot(x, y)
136
+ plt.xlabel('x')
137
+ plt.title('logistic sigmoid(x)')
138
+ plt.show()
139
+ ```
140
+
141
+
142
+
143
+ ![png](output_6_0.png)
144
+
145
+
146
+
147
+
148
+ ```python
149
+ #数据集分为训练集和测试集,训练集有60000条数据,测试集有10000条数据,
150
+ #每一条数据都是由785个数字组成,数值大小在0~255之间,第一个数字代表该条数据所表示的数字,
151
+ #后面的784个数字可以形成28×28的矩阵(28x28=784),每一个数值都对应该位置的像素点的像素值灰度大小,由此形成了一幅像素为28×28的图片。
152
+
153
+ #这里是训练集
154
+
155
+ test_data_file = open("mnist_train.csv", 'r')
156
+ test_data_list = test_data_file.readlines()
157
+ test_data_file.close()
158
+ print("总数据量:",len(test_data_list))
159
+ print("第1条数据:",test_data_list[0])
160
+ print("第1条数据表示的数字:",test_data_list[0][0])
161
+ print("第1条数据的28x28矩阵数据:",test_data_list[0][1:])
162
+ ```
163
+
164
+ 总数据量: 60000
165
+ 第1条数据: 5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,18,18,18,126,136,175,26,166,255,247,127,0,0,0,0,0,0,0,0,0,0,0,0,30,36,94,154,170,253,253,253,253,253,225,172,253,242,195,64,0,0,0,0,0,0,0,0,0,0,0,49,238,253,253,253,253,253,253,253,253,251,93,82,82,56,39,0,0,0,0,0,0,0,0,0,0,0,0,18,219,253,253,253,253,253,198,182,247,241,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,80,156,107,253,253,205,11,0,43,154,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,14,1,154,253,90,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,139,253,190,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,11,190,253,70,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,35,241,225,160,108,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,81,240,253,253,119,25,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,45,186,253,253,150,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,93,252,253,187,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,249,253,249,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,46,130,183,253,253,207,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,39,148,229,253,253,253,250,182,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,114,221,253,253,253,253,201,78,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,23,66,213,253,253,253,253,198,81,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,18,171,219,253,253,253,253,195,80,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,55,172,226,253,253,253,253,244,133,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,136,253,253,253,212,135,132,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
166
+
167
+ 第1条数据表示的数字: 5
168
+ 第1条数据的28x28矩阵数据: ,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,18,18,18,126,136,175,26,166,255,247,127,0,0,0,0,0,0,0,0,0,0,0,0,30,36,94,154,170,253,253,253,253,253,225,172,253,242,195,64,0,0,0,0,0,0,0,0,0,0,0,49,238,253,253,253,253,253,253,253,253,251,93,82,82,56,39,0,0,0,0,0,0,0,0,0,0,0,0,18,219,253,253,253,253,253,198,182,247,241,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,80,156,107,253,253,205,11,0,43,154,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,14,1,154,253,90,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,139,253,190,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,11,190,253,70,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,35,241,225,160,108,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,81,240,253,253,119,25,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,45,186,253,253,150,27,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,93,252,253,187,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,249,253,249,64,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,46,130,183,253,253,207,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,39,148,229,253,253,253,250,182,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,24,114,221,253,253,253,253,201,78,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,23,66,213,253,253,253,253,198,81,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,18,171,219,253,253,253,253,195,80,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,55,172,226,253,253,253,253,244,133,11,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,136,253,253,253,212,135,132,16,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
169
+
170
+
171
+
172
+
173
+ ```python
174
+ all_values = test_data_list[0].split(',') # split()函数将第1条数据进行拆分,以‘,’为分界点进行拆分
175
+ image_array = np.asfarray(all_values[1:]).reshape((28,28)) # asfarray()函数将all_values中的后784个数字进行重新排列
176
+ # reshape()函数可以对数组进行整型,使其成为28×28的二维数组,asfarry()函数可以使其成为矩阵。
177
+ plt.imshow(image_array, interpolation = 'nearest') # imshow()函数可以将28×28的矩阵中的数值当做像素值,使其形成图片
178
+ ```
179
+
180
+
181
+
182
+
183
+ <matplotlib.image.AxesImage at 0x7fa3da4adfd0>
184
+
185
+
186
+
187
+
188
+
189
+ ![png](output_8_1.png)
190
+
191
+
192
+
193
+
194
+ ```python
195
+ #接下去是第1层和最后1层的逻辑
196
+ ```
197
+
198
+
199
+ ```python
200
+ # 对输入的数据进行处理,取后784个数据除以255,再乘以0.99,最后加上0。01,是所有的数据都在0.01到1.00之间
201
+ inputs = (np.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01 #输入层,784个输入
202
+ # 建立准确输出结果矩阵,对应的位置标签数值为0.99,其他位置为0.01
203
+ #最终实现将0~255转换为0~1的浮点数
204
+ #可视化中间输出
205
+ print(inputs)
206
+ middle_layer_fig = np.asfarray((inputs-0.01)/0.99*255.0 )
207
+ middle_layer_fig = np.asfarray(middle_layer_fig).reshape((28,28))
208
+ plt.imshow(middle_layer_fig, interpolation = 'nearest')
209
+ ```
210
+
211
+ [0.01 0.01 0.01 0.01 0.01 0.01
212
+ 0.01 0.01 0.01 0.01 0.01 0.01
213
+ 0.01 0.01 0.01 0.01 0.01 0.01
214
+ 0.01 0.01 0.01 0.01 0.01 0.01
215
+ 0.01 0.01 0.01 0.01 0.01 0.01
216
+ 0.01 0.01 0.01 0.01 0.01 0.01
217
+ 0.01 0.01 0.01 0.01 0.01 0.01
218
+ 0.01 0.01 0.01 0.01 0.01 0.01
219
+ 0.01 0.01 0.01 0.01 0.01 0.01
220
+ 0.01 0.01 0.01 0.01 0.01 0.01
221
+ 0.01 0.01 0.01 0.01 0.01 0.01
222
+ 0.01 0.01 0.01 0.01 0.01 0.01
223
+ 0.01 0.01 0.01 0.01 0.01 0.01
224
+ 0.01 0.01 0.01 0.01 0.01 0.01
225
+ 0.01 0.01 0.01 0.01 0.01 0.01
226
+ 0.01 0.01 0.01 0.01 0.01 0.01
227
+ 0.01 0.01 0.01 0.01 0.01 0.01
228
+ 0.01 0.01 0.01 0.01 0.01 0.01
229
+ 0.01 0.01 0.01 0.01 0.01 0.01
230
+ 0.01 0.01 0.01 0.01 0.01 0.01
231
+ 0.01 0.01 0.01 0.01 0.01 0.01
232
+ 0.01 0.01 0.01 0.01 0.01 0.01
233
+ 0.01 0.01 0.01 0.01 0.01 0.01
234
+ 0.01 0.01 0.01 0.01 0.01 0.01
235
+ 0.01 0.01 0.01 0.01 0.01 0.01
236
+ 0.01 0.01 0.02164706 0.07988235 0.07988235 0.07988235
237
+ 0.49917647 0.538 0.68941176 0.11094118 0.65447059 1.
238
+ 0.96894118 0.50305882 0.01 0.01 0.01 0.01
239
+ 0.01 0.01 0.01 0.01 0.01 0.01
240
+ 0.01 0.01 0.12647059 0.14976471 0.37494118 0.60788235
241
+ 0.67 0.99223529 0.99223529 0.99223529 0.99223529 0.99223529
242
+ 0.88352941 0.67776471 0.99223529 0.94952941 0.76705882 0.25847059
243
+ 0.01 0.01 0.01 0.01 0.01 0.01
244
+ 0.01 0.01 0.01 0.01 0.01 0.20023529
245
+ 0.934 0.99223529 0.99223529 0.99223529 0.99223529 0.99223529
246
+ 0.99223529 0.99223529 0.99223529 0.98447059 0.37105882 0.32835294
247
+ 0.32835294 0.22741176 0.16141176 0.01 0.01 0.01
248
+ 0.01 0.01 0.01 0.01 0.01 0.01
249
+ 0.01 0.01 0.01 0.07988235 0.86023529 0.99223529
250
+ 0.99223529 0.99223529 0.99223529 0.99223529 0.77870588 0.71658824
251
+ 0.96894118 0.94564706 0.01 0.01 0.01 0.01
252
+ 0.01 0.01 0.01 0.01 0.01 0.01
253
+ 0.01 0.01 0.01 0.01 0.01 0.01
254
+ 0.01 0.01 0.32058824 0.61564706 0.42541176 0.99223529
255
+ 0.99223529 0.80588235 0.05270588 0.01 0.17694118 0.60788235
256
+ 0.01 0.01 0.01 0.01 0.01 0.01
257
+ 0.01 0.01 0.01 0.01 0.01 0.01
258
+ 0.01 0.01 0.01 0.01 0.01 0.01
259
+ 0.01 0.06435294 0.01388235 0.60788235 0.99223529 0.35941176
260
+ 0.01 0.01 0.01 0.01 0.01 0.01
261
+ 0.01 0.01 0.01 0.01 0.01 0.01
262
+ 0.01 0.01 0.01 0.01 0.01 0.01
263
+ 0.01 0.01 0.01 0.01 0.01 0.01
264
+ 0.01 0.54964706 0.99223529 0.74764706 0.01776471 0.01
265
+ 0.01 0.01 0.01 0.01 0.01 0.01
266
+ 0.01 0.01 0.01 0.01 0.01 0.01
267
+ 0.01 0.01 0.01 0.01 0.01 0.01
268
+ 0.01 0.01 0.01 0.01 0.01 0.05270588
269
+ 0.74764706 0.99223529 0.28176471 0.01 0.01 0.01
270
+ 0.01 0.01 0.01 0.01 0.01 0.01
271
+ 0.01 0.01 0.01 0.01 0.01 0.01
272
+ 0.01 0.01 0.01 0.01 0.01 0.01
273
+ 0.01 0.01 0.01 0.01 0.14588235 0.94564706
274
+ 0.88352941 0.63117647 0.42929412 0.01388235 0.01 0.01
275
+ 0.01 0.01 0.01 0.01 0.01 0.01
276
+ 0.01 0.01 0.01 0.01 0.01 0.01
277
+ 0.01 0.01 0.01 0.01 0.01 0.01
278
+ 0.01 0.01 0.01 0.32447059 0.94176471 0.99223529
279
+ 0.99223529 0.472 0.10705882 0.01 0.01 0.01
280
+ 0.01 0.01 0.01 0.01 0.01 0.01
281
+ 0.01 0.01 0.01 0.01 0.01 0.01
282
+ 0.01 0.01 0.01 0.01 0.01 0.01
283
+ 0.01 0.01 0.18470588 0.73211765 0.99223529 0.99223529
284
+ 0.59235294 0.11482353 0.01 0.01 0.01 0.01
285
+ 0.01 0.01 0.01 0.01 0.01 0.01
286
+ 0.01 0.01 0.01 0.01 0.01 0.01
287
+ 0.01 0.01 0.01 0.01 0.01 0.01
288
+ 0.01 0.07211765 0.37105882 0.98835294 0.99223529 0.736
289
+ 0.01 0.01 0.01 0.01 0.01 0.01
290
+ 0.01 0.01 0.01 0.01 0.01 0.01
291
+ 0.01 0.01 0.01 0.01 0.01 0.01
292
+ 0.01 0.01 0.01 0.01 0.01 0.01
293
+ 0.01 0.97670588 0.99223529 0.97670588 0.25847059 0.01
294
+ 0.01 0.01 0.01 0.01 0.01 0.01
295
+ 0.01 0.01 0.01 0.01 0.01 0.01
296
+ 0.01 0.01 0.01 0.01 0.01 0.01
297
+ 0.01 0.01 0.18858824 0.51470588 0.72047059 0.99223529
298
+ 0.99223529 0.81364706 0.01776471 0.01 0.01 0.01
299
+ 0.01 0.01 0.01 0.01 0.01 0.01
300
+ 0.01 0.01 0.01 0.01 0.01 0.01
301
+ 0.01 0.01 0.01 0.01 0.16141176 0.58458824
302
+ 0.89905882 0.99223529 0.99223529 0.99223529 0.98058824 0.71658824
303
+ 0.01 0.01 0.01 0.01 0.01 0.01
304
+ 0.01 0.01 0.01 0.01 0.01 0.01
305
+ 0.01 0.01 0.01 0.01 0.01 0.01
306
+ 0.10317647 0.45258824 0.868 0.99223529 0.99223529 0.99223529
307
+ 0.99223529 0.79035294 0.31282353 0.01 0.01 0.01
308
+ 0.01 0.01 0.01 0.01 0.01 0.01
309
+ 0.01 0.01 0.01 0.01 0.01 0.01
310
+ 0.01 0.01 0.09929412 0.26623529 0.83694118 0.99223529
311
+ 0.99223529 0.99223529 0.99223529 0.77870588 0.32447059 0.01776471
312
+ 0.01 0.01 0.01 0.01 0.01 0.01
313
+ 0.01 0.01 0.01 0.01 0.01 0.01
314
+ 0.01 0.01 0.01 0.01 0.07988235 0.67388235
315
+ 0.86023529 0.99223529 0.99223529 0.99223529 0.99223529 0.76705882
316
+ 0.32058824 0.04494118 0.01 0.01 0.01 0.01
317
+ 0.01 0.01 0.01 0.01 0.01 0.01
318
+ 0.01 0.01 0.01 0.01 0.01 0.01
319
+ 0.22352941 0.67776471 0.88741176 0.99223529 0.99223529 0.99223529
320
+ 0.99223529 0.95729412 0.52635294 0.05270588 0.01 0.01
321
+ 0.01 0.01 0.01 0.01 0.01 0.01
322
+ 0.01 0.01 0.01 0.01 0.01 0.01
323
+ 0.01 0.01 0.01 0.01 0.538 0.99223529
324
+ 0.99223529 0.99223529 0.83305882 0.53411765 0.52247059 0.07211765
325
+ 0.01 0.01 0.01 0.01 0.01 0.01
326
+ 0.01 0.01 0.01 0.01 0.01 0.01
327
+ 0.01 0.01 0.01 0.01 0.01 0.01
328
+ 0.01 0.01 0.01 0.01 0.01 0.01
329
+ 0.01 0.01 0.01 0.01 0.01 0.01
330
+ 0.01 0.01 0.01 0.01 0.01 0.01
331
+ 0.01 0.01 0.01 0.01 0.01 0.01
332
+ 0.01 0.01 0.01 0.01 0.01 0.01
333
+ 0.01 0.01 0.01 0.01 0.01 0.01
334
+ 0.01 0.01 0.01 0.01 0.01 0.01
335
+ 0.01 0.01 0.01 0.01 0.01 0.01
336
+ 0.01 0.01 0.01 0.01 0.01 0.01
337
+ 0.01 0.01 0.01 0.01 0.01 0.01
338
+ 0.01 0.01 0.01 0.01 0.01 0.01
339
+ 0.01 0.01 0.01 0.01 0.01 0.01
340
+ 0.01 0.01 0.01 0.01 0.01 0.01
341
+ 0.01 0.01 0.01 0.01 ]
342
+
343
+
344
+
345
+
346
+
347
+ <matplotlib.image.AxesImage at 0x7fa3da408d00>
348
+
349
+
350
+
351
+
352
+
353
+ ![png](output_10_2.png)
354
+
355
+
356
+
357
+
358
+ ```python
359
+ targets = np.zeros(out_nodes) + 0.01
360
+ #输出层,10个数字,10个输出,0~1的概率范围
361
+ #输出层是1个list,由10个数字组成,第一个数字代表0的概率,依次类推,第10个数字代表9的概率
362
+ #这里是输出的[理想结果]
363
+ # all_values[0] is the target label for this record
364
+ #可视化中间输出
365
+ print(len(targets))
366
+ print(targets)
367
+ middle_layer_fig = np.asfarray((targets-0.01)/0.99*255.0 )
368
+ middle_layer_fig = np.asfarray(middle_layer_fig).reshape((1,10))
369
+ plt.imshow(middle_layer_fig, interpolation = 'nearest')
370
+ ```
371
+
372
+ 10
373
+ [0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01]
374
+
375
+
376
+
377
+
378
+
379
+ <matplotlib.image.AxesImage at 0x7fa3da3ad490>
380
+
381
+
382
+
383
+
384
+
385
+ ![png](output_11_2.png)
386
+
387
+
388
+
389
+
390
+ ```python
391
+ #print("第1行数据:",all_values)
392
+ #print("第1行数据所表示的数字:",all_values[0])
393
+ targets[int(all_values[0])] = 0.99
394
+ #将数据集的数据表示的数字在其指定的输出层的概率位置上的概率置0.99
395
+ #这里是第1行数据,对应的是数组5,因此按照其在输出层的表示的概率位置,应当将第6个数字改为0.99
396
+ #可视化中间输出
397
+ print(targets)
398
+ middle_layer_fig = np.asfarray((targets-0.01)/0.99*255.0 )
399
+ middle_layer_fig = np.asfarray(middle_layer_fig).reshape((1,10))
400
+ plt.imshow(middle_layer_fig, interpolation = 'nearest')
401
+ ```
402
+
403
+ [0.01 0.01 0.01 0.01 0.01 0.99 0.01 0.01 0.01 0.01]
404
+
405
+
406
+
407
+
408
+
409
+ <matplotlib.image.AxesImage at 0x7fa3da30b4f0>
410
+
411
+
412
+
413
+
414
+
415
+ ![png](output_12_2.png)
416
+
417
+
418
+
419
+
420
+ ```python
421
+ #对比
422
+ targets = np.zeros(out_nodes) + 0.01
423
+ print(targets)
424
+ targets[int(all_values[0])] = 0.99
425
+ print(targets)
426
+ ```
427
+
428
+ [0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01]
429
+ [0.01 0.01 0.01 0.01 0.01 0.99 0.01 0.01 0.01 0.01]
430
+
431
+
432
+
433
+ ```python
434
+ #接下去是训练逻辑,训练的目标就是让输入的数据的概率尽可能接近理想结果
435
+ ```
436
+
437
+
438
+ ```python
439
+ # 将导入的输入列表数据和正确的输出结果转换成二维矩阵
440
+ INPUT = np.array(inputs, ndmin = 2).T # array函数是矩阵生成函数,将输入的inputs转换成二维矩阵,ndmin=2表示二维矩阵
441
+ TARGETS = np.array(targets, ndmin = 2).T # .T表示矩阵的转置,生成后的矩阵的转置矩阵送入变量targets
442
+ #print(INPUT)
443
+ #print(TARGETS)
444
+ ```
445
+
446
+
447
+ ```python
448
+ # 进行前向传播
449
+ # 利用导入的数据计算进入隐藏层的数据
450
+ hidden_inputs = np.dot(wih, INPUT) # dot()函数是指两个矩阵做点乘
451
+ #可视化中间输出
452
+ print(hidden_inputs.T)
453
+ # Visualize hidden layer activations
454
+ #hidden_inputs = hidden_inputs.reshape((20, 10))
455
+ plt.imshow(hidden_inputs.T, cmap='hot', aspect='auto')
456
+ plt.xlabel('Hidden Node')
457
+ plt.ylabel('Sample')
458
+ plt.title('Hidden Layer Activations')
459
+ plt.colorbar()
460
+ plt.show()
461
+ ```
462
+
463
+ [[-8.64964137e-01 -1.96617581e+00 7.43349647e-01 6.52699592e-01
464
+ 3.90933284e-01 1.23038702e+00 -1.26960367e-01 -8.89064451e-01
465
+ 1.25956352e-01 -2.66009122e-01 -3.87411628e-01 -8.55340714e-01
466
+ -3.73072385e-01 -4.88004264e-01 8.93519640e-01 -5.94015812e-01
467
+ -3.94940660e-01 -6.03127644e-01 -1.83468156e-01 1.21212338e+00
468
+ 1.11156836e+00 -3.30592481e-03 -1.45441494e-01 -1.16176875e-01
469
+ -6.79873194e-01 1.35716864e-03 -9.88715475e-01 2.53326180e-01
470
+ 7.95912751e-02 -8.71915904e-01 -4.99039240e-01 -9.32069427e-02
471
+ -1.29952079e+00 -1.18946859e-01 -2.22242548e-01 1.07578559e+00
472
+ 1.69691315e-01 -4.42288856e-01 1.18089766e+00 3.81134469e-02
473
+ 3.15796540e-01 1.07374634e+00 -7.71978830e-01 -1.77028239e-01
474
+ 7.83445294e-01 1.16099348e+00 5.28529106e-01 -1.94025187e-02
475
+ 2.00808369e-01 6.72844377e-01 1.21480995e+00 -2.05275063e-01
476
+ -1.02432531e+00 -1.40022847e+00 7.16467553e-01 -6.38000445e-01
477
+ -1.44617295e-01 4.72539610e-01 -6.51132050e-02 -1.02462391e+00
478
+ 1.38454078e+00 7.12628876e-01 7.39171671e-02 -3.34221329e-01
479
+ 5.03935486e-01 2.08522402e+00 2.29977865e-01 -8.58595299e-01
480
+ 9.14983758e-01 5.27664003e-02 -3.49103724e-01 -1.29338789e+00
481
+ 8.10453241e-01 2.08934398e+00 1.66835420e+00 -1.12660303e+00
482
+ -1.12181011e-01 1.70474734e-01 5.20577595e-01 6.00166910e-01
483
+ -3.81956593e-01 1.30122404e-01 -5.23356991e-01 -1.01661725e+00
484
+ -3.38834016e-01 6.30692963e-01 1.17169833e-01 9.13183907e-01
485
+ -1.10728477e+00 9.91458051e-01 -2.88315338e-01 7.70893096e-01
486
+ 5.82703388e-01 -9.29590575e-02 -1.26294025e+00 1.94053320e-01
487
+ -5.96912464e-01 2.60424259e-01 4.29504575e-02 -7.60243022e-01
488
+ 2.03240513e-02 7.27749904e-02 -7.19974851e-01 5.25634269e-01
489
+ -4.96678397e-01 -1.62713415e+00 2.89082887e-01 -5.26173924e-01
490
+ -3.82685176e-01 -1.76410064e+00 -1.33431697e+00 4.32481392e-01
491
+ 2.33941967e+00 7.52802920e-01 2.17849572e-01 -8.38437665e-02
492
+ -5.51882457e-01 1.84692442e+00 -4.10696115e-01 3.97851800e-01
493
+ -1.49071923e-01 -2.81875633e-01 1.95378425e+00 -4.66989868e-01
494
+ -4.73375650e-01 1.66522535e-01 5.01408007e-01 -1.30089311e-01
495
+ 1.44543864e+00 4.28063957e-01 3.86986466e-01 6.62182100e-01
496
+ -1.39480966e-01 -1.82625599e-01 -3.67218386e-01 -1.48826110e+00
497
+ -4.31214177e-01 -8.92040712e-01 -4.15032383e-01 -3.76042786e-01
498
+ -3.83971840e-01 7.49005651e-01 -3.16839497e-01 -7.70655367e-01
499
+ 3.56918546e-01 -1.93469779e-01 -4.51644191e-01 -5.20009826e-01
500
+ 7.61656212e-01 -5.39819400e-01 1.24457323e-01 4.02348827e-01
501
+ 4.96390519e-02 -1.61507281e-01 -6.04062425e-01 4.77674466e-01
502
+ 5.65500425e-01 -1.74931564e-02 1.82237163e-01 -2.52744493e-01
503
+ -9.74909666e-01 4.39247112e-01 2.50623145e-01 -5.47588554e-01
504
+ -1.10213410e+00 -7.96484480e-03 8.18154047e-01 -5.31161336e-01
505
+ 9.45395512e-02 -4.80934079e-02 -4.15248499e-01 2.01334670e-02
506
+ -7.73149020e-01 5.16150140e-01 -1.11187297e+00 -3.84973353e-01
507
+ 1.57056302e-01 9.52205562e-02 -4.17473666e-04 -2.64269971e-01
508
+ 3.51661057e-02 -8.62097845e-01 -6.41290441e-01 -6.10216699e-01
509
+ 1.48703377e+00 -9.36182669e-01 2.29758638e-01 2.69581850e-03
510
+ -9.90544195e-03 -1.16945542e-01 2.16055208e-01 -5.16034753e-01
511
+ -5.47460522e-01 1.21898405e+00 -1.40917054e-01 -1.10955125e+00
512
+ -1.06838867e+00 -8.16027514e-01 3.18583449e-01 7.11316110e-01]]
513
+
514
+
515
+
516
+
517
+ ![png](output_16_1.png)
518
+
519
+
520
+
521
+
522
+ ```python
523
+ # 利用激活函数sigmoid计算隐藏层输出的数据
524
+ hidden_outputs = activation_function(hidden_inputs)
525
+ #可视化中间输出
526
+ print(hidden_outputs.T)
527
+ # Visualize hidden layer activations
528
+ plt.imshow(hidden_outputs.T, cmap='hot', aspect='auto')
529
+ plt.xlabel('Hidden Node')
530
+ plt.ylabel('Sample')
531
+ plt.title('Hidden Layer Activations')
532
+ plt.colorbar()
533
+ plt.show()
534
+ ```
535
+
536
+ [[0.29630324 0.12280024 0.6777279 0.65761855 0.59650735 0.7738863
537
+ 0.46830247 0.29130293 0.53144752 0.43388711 0.40434055 0.29831372
538
+ 0.40779883 0.38036382 0.70961597 0.35571397 0.40252851 0.35362846
539
+ 0.45426119 0.77067444 0.75242139 0.49917352 0.46370359 0.4709884
540
+ 0.33628961 0.50033929 0.27116587 0.56299502 0.51988732 0.2948558
541
+ 0.37776648 0.47671512 0.21424568 0.4702983 0.44466693 0.74569562
542
+ 0.54232132 0.39119572 0.76510917 0.50952721 0.5782995 0.74530871
543
+ 0.3160512 0.45585816 0.68642218 0.76151319 0.62913998 0.49514952
544
+ 0.55003407 0.66213977 0.77114891 0.44886068 0.26418574 0.19777986
545
+ 0.67182867 0.34569868 0.46390856 0.61598467 0.48372745 0.2641277
546
+ 0.79971928 0.67098178 0.51847088 0.41721386 0.62338374 0.88945871
547
+ 0.55724239 0.29763291 0.71401891 0.51318854 0.41359978 0.21527993
548
+ 0.69220608 0.88986315 0.84135627 0.24478854 0.47198412 0.54251577
549
+ 0.62728282 0.64569449 0.40565508 0.53248478 0.37206759 0.26568684
550
+ 0.41609274 0.65264657 0.52925899 0.71365125 0.24837744 0.72937582
551
+ 0.42841635 0.68371406 0.64168922 0.47677696 0.22046816 0.54836166
552
+ 0.35505039 0.56474058 0.51073596 0.31859351 0.50508084 0.51818572
553
+ 0.32739852 0.6284643 0.37832157 0.16422333 0.57177159 0.3714097
554
+ 0.40547943 0.14627751 0.20844618 0.60646605 0.91208956 0.67978913
555
+ 0.55424802 0.47905133 0.36542777 0.86376559 0.39874522 0.59817142
556
+ 0.46280088 0.429994 0.87585869 0.38532895 0.38381758 0.5415347
557
+ 0.62279016 0.46752346 0.80929544 0.60541126 0.59555704 0.6597504
558
+ 0.46518618 0.45447007 0.40921333 0.18418287 0.39383643 0.29068888
559
+ 0.39770607 0.40708168 0.4051693 0.678962 0.42144618 0.31633735
560
+ 0.5882943 0.45178286 0.38896992 0.37284994 0.68171321 0.3682296
561
+ 0.53107423 0.59925186 0.51240722 0.45971072 0.35341483 0.61719858
562
+ 0.63772427 0.49562682 0.54543362 0.4371481 0.27390298 0.60807962
563
+ 0.56232987 0.36642406 0.24934024 0.4980088 0.69384436 0.37024607
564
+ 0.5236173 0.48797896 0.3976543 0.5050332 0.3157983 0.6262471
565
+ 0.24752187 0.40492795 0.53918356 0.52378717 0.49989563 0.43431435
566
+ 0.50879062 0.29690123 0.34495489 0.35200977 0.81563264 0.28167207
567
+ 0.5571883 0.50067395 0.49752366 0.47079689 0.55380467 0.37377991
568
+ 0.36645379 0.77188471 0.46482892 0.24795456 0.25570964 0.30660756
569
+ 0.57897899 0.67069191]]
570
+
571
+
572
+
573
+
574
+ ![png](output_17_1.png)
575
+
576
+
577
+
578
+
579
+ ```python
580
+ # 利用隐藏层输出的数据计算导入输出层的数据
581
+ final_inputs = np.dot(who, hidden_outputs) # dot()函数是指两个矩阵做点乘
582
+ #可视化中间输出
583
+ print(final_inputs.T)
584
+ middle_layer_fig = np.asfarray((final_inputs-0.01)/0.99*255.0 )
585
+ middle_layer_fig = np.asfarray(middle_layer_fig).reshape((1,10))
586
+ plt.imshow(middle_layer_fig, interpolation = 'nearest')
587
+ ```
588
+
589
+ [[ 0.56028136 0.82552015 0.34670209 0.17793798 -0.66372393 -0.37233255
590
+ -0.39555073 -0.76359914 -0.48399976 -0.23884983]]
591
+
592
+
593
+
594
+
595
+
596
+ <matplotlib.image.AxesImage at 0x7fa3d89ffc10>
597
+
598
+
599
+
600
+
601
+
602
+ ![png](output_18_2.png)
603
+
604
+
605
+
606
+
607
+ ```python
608
+ # Or visualize final outputs as a heatmap
609
+ plt.imshow(final_inputs, cmap='hot', aspect='auto')
610
+ plt.xlabel('Output Node')
611
+ plt.ylabel('Sample')
612
+ plt.title('Final Inputs')
613
+ plt.colorbar()
614
+ plt.show()
615
+ ```
616
+
617
+
618
+
619
+ ![png](output_19_0.png)
620
+
621
+
622
+
623
+
624
+ ```python
625
+ # Visualize final layer inputs
626
+ plt.bar(range(out_nodes), final_inputs.flatten())
627
+ plt.xlabel('Output Node')
628
+ plt.ylabel('Input Value')
629
+ plt.title('Final Layer Inputs')
630
+ plt.show()
631
+ ```
632
+
633
+
634
+
635
+ ![png](output_20_0.png)
636
+
637
+
638
+
639
+
640
+ ```python
641
+ # 利用激活函数sigmoid计算输出层的输出结果
642
+ final_outputs = activation_function(final_inputs)
643
+ # 前向传播结束
644
+
645
+ #可视化中间输出
646
+ print(final_outputs.T)
647
+ middle_layer_fig = np.asfarray((final_outputs-0.01)/0.99*255.0 )
648
+ middle_layer_fig = np.asfarray(middle_layer_fig).reshape((1,10))
649
+ plt.imshow(middle_layer_fig, interpolation = 'nearest')
650
+ ```
651
+
652
+ [[0.63651764 0.69540686 0.58581762 0.54436749 0.33990358 0.40797752
653
+ 0.40238179 0.31786536 0.38130809 0.44056981]]
654
+
655
+
656
+
657
+
658
+
659
+ <matplotlib.image.AxesImage at 0x7fa3da5f10a0>
660
+
661
+
662
+
663
+
664
+
665
+ ![png](output_21_2.png)
666
+
667
+
668
+
669
+
670
+ ```python
671
+ # Or visualize final outputs as a heatmap
672
+ plt.imshow(final_outputs, cmap='hot', aspect='auto')
673
+ plt.xlabel('Output Node')
674
+ plt.ylabel('Sample')
675
+ plt.title('Final Outputs')
676
+ plt.colorbar()
677
+ plt.show()
678
+ ```
679
+
680
+
681
+
682
+ ![png](output_22_0.png)
683
+
684
+
685
+
686
+
687
+ ```python
688
+ # Visualize final layer outputs (sigmoid)
689
+ plt.bar(range(out_nodes), final_outputs.flatten())
690
+ plt.xlabel('Output Node')
691
+ plt.ylabel('Input Value')
692
+ plt.title('Final Layer Inputs')
693
+ plt.show()
694
+ ```
695
+
696
+
697
+
698
+ ![png](output_23_0.png)
699
+
700
+
701
+
702
+
703
+ ```python
704
+ # 进行反向传播
705
+ # 计算前向传播得到的输出结果与正确值之间的误差
706
+ output_errors = TARGETS - final_outputs
707
+
708
+ #可视化中间输出
709
+ print(output_errors.T)
710
+ middle_layer_fig = np.asfarray((output_errors-0.01)/0.99*255.0 )
711
+ middle_layer_fig = np.asfarray(middle_layer_fig).reshape((1,10))
712
+ plt.imshow(middle_layer_fig, interpolation = 'nearest')
713
+ ```
714
+
715
+ [[-0.62651764 -0.68540686 -0.57581762 -0.53436749 -0.32990358 0.58202248
716
+ -0.39238179 -0.30786536 -0.37130809 -0.43056981]]
717
+
718
+
719
+
720
+
721
+
722
+ <matplotlib.image.AxesImage at 0x7fa3d87db8b0>
723
+
724
+
725
+
726
+
727
+
728
+ ![png](output_24_2.png)
729
+
730
+
731
+
732
+
733
+ ```python
734
+ # Visualize output errors as a bar chart
735
+ plt.bar(range(out_nodes), output_errors.flatten())
736
+ plt.xlabel('Output Node')
737
+ plt.ylabel('Error Value')
738
+ plt.title('Output Errors')
739
+ plt.show()
740
+ ```
741
+
742
+
743
+
744
+ ![png](output_25_0.png)
745
+
746
+
747
+
748
+
749
+ ```python
750
+ # Or visualize output errors as a scatter plot
751
+ plt.scatter(range(out_nodes), output_errors.flatten())
752
+ plt.xlabel('Output Node')
753
+ plt.ylabel('Error Value')
754
+ plt.title('Output Errors')
755
+ plt.show()
756
+ ```
757
+
758
+
759
+
760
+ ![png](output_26_0.png)
761
+
762
+
763
+
764
+
765
+ ```python
766
+ # 隐藏层的误差是由输出层的误差通过两个层之间的权重矩阵进行分配的,在隐藏层重新结合
767
+ ```
768
+
769
+
770
+ ```python
771
+ hidden_errors = np.dot(who.T, output_errors) # 隐藏层与输出层之间的权重矩阵的转置与前向传播的误差矩阵的点乘
772
+ #可视化中间输出
773
+ print(hidden_errors.T)
774
+ #middle_layer_fig = np.asfarray((hidden_errors-0.01)/0.99*255.0 )
775
+ #middle_layer_fig = np.asfarray(middle_layer_fig).reshape((20,10))
776
+ #plt.imshow(middle_layer_fig, interpolation = 'nearest')
777
+ ```
778
+
779
+ [[ 0.23145703 -0.09199276 -0.12220719 -0.15896069 0.06424253 0.10197068
780
+ -0.23125848 -0.00782811 -0.08381227 -0.11514534 -0.09644854 -0.12429981
781
+ 0.11276763 -0.26363747 -0.00989155 -0.14107911 0.27482566 0.10077863
782
+ 0.08727872 -0.12703169 0.04482464 0.07979755 -0.08780178 -0.10513761
783
+ -0.00644824 -0.11657829 -0.04453468 0.05577635 0.01531368 0.13738715
784
+ 0.03474212 0.22550981 -0.08763767 -0.06505764 -0.11262462 -0.04158586
785
+ -0.09128322 -0.01086248 0.05525096 -0.12434499 0.17656152 0.04339815
786
+ -0.03433653 -0.11152836 0.03669448 -0.01467246 0.01413861 0.17155288
787
+ -0.12223192 -0.10968683 0.10515451 0.14353315 0.08262463 0.16657906
788
+ -0.10807233 -0.10796653 -0.01689826 0.05175527 -0.02711501 -0.06925127
789
+ 0.24918363 -0.0658346 -0.01650576 -0.14181141 -0.06328054 0.11752269
790
+ 0.07361948 -0.25658514 -0.03837734 0.05291595 0.18022871 -0.02485894
791
+ -0.11155773 -0.17969543 0.05235072 -0.03868002 0.07991305 -0.00944794
792
+ 0.01358124 -0.04854606 -0.11433062 -0.11457118 -0.10174756 0.08157923
793
+ -0.07922054 0.16252699 -0.0668835 0.02633577 -0.25292949 -0.00164063
794
+ 0.17719827 -0.27838094 0.06372956 -0.08327759 -0.1045452 0.0994223
795
+ -0.18854096 0.01717639 -0.22337965 -0.05331426 -0.09068925 0.00909319
796
+ -0.11275048 0.02400681 0.15580461 0.04395622 0.05191163 0.07671998
797
+ -0.07357827 0.04857611 0.01200461 -0.01824155 0.20218933 -0.01648541
798
+ -0.08841815 -0.22972757 -0.06564815 0.25879827 0.03363929 -0.08144042
799
+ -0.00117747 0.04931258 -0.28733007 0.09207885 -0.11084745 0.03480787
800
+ -0.30290225 0.02605289 -0.03273764 0.13374028 0.06733113 -0.08264645
801
+ -0.10579 -0.16626817 -0.19349467 0.2339928 0.25338442 -0.04781617
802
+ 0.01431193 -0.06614716 -0.03706169 -0.18027598 0.03546684 0.07375848
803
+ -0.13524866 -0.14490857 -0.21459248 0.1796899 0.02376605 -0.02517879
804
+ 0.00632407 0.03003414 -0.11537092 0.03510202 0.07357026 0.0971219
805
+ -0.08266574 0.03720117 0.09910707 -0.04312925 -0.08307132 0.02983252
806
+ 0.01496464 0.07249455 -0.1618727 0.11377448 -0.03207163 0.19216192
807
+ 0.09118743 0.01690548 -0.06923089 0.02959015 0.20129512 -0.04899694
808
+ 0.1233579 -0.20508642 0.01812198 -0.00063595 0.17360329 0.11723159
809
+ 0.15777609 0.07835488 -0.05387801 -0.01755501 0.10815374 0.22098465
810
+ -0.12040005 0.025853 -0.08475004 0.24887947 0.07332807 0.0784619
811
+ 0.01351764 -0.08704183 0.08712977 0.0756019 -0.04051772 -0.15931343
812
+ -0.04228901 0.13588616]]
813
+
814
+
815
+
816
+ ```python
817
+ # Visualize hidden errors as a bar chart
818
+ plt.bar(range(hide_nodes), hidden_errors.flatten())
819
+ plt.xlabel('Hidden Node')
820
+ plt.ylabel('Error Value')
821
+ plt.title('Hidden Errors')
822
+ plt.show()
823
+ ```
824
+
825
+
826
+
827
+ ![png](output_29_0.png)
828
+
829
+
830
+
831
+
832
+ ```python
833
+ # Or visualize hidden errors as a scatter plot
834
+ plt.scatter(range(hide_nodes), hidden_errors.flatten())
835
+ plt.xlabel('Hidden Node')
836
+ plt.ylabel('Error Value')
837
+ plt.title('Hidden Errors')
838
+ plt.show()
839
+ ```
840
+
841
+
842
+
843
+ ![png](output_30_0.png)
844
+
845
+
846
+
847
+
848
+ ```python
849
+ # 对隐藏层与输出层之间的权重矩阵进行更新迭代
850
+ who += learningrate * np.dot((output_errors * final_outputs * (1.0 - final_outputs)),np.transpose(hidden_outputs))
851
+ # 对输入层与隐藏层之间的权重矩阵进行更新迭代
852
+ wih += learningrate * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(INPUT))
853
+ ```
854
+
855
+
856
+ ```python
857
+ #第一次迭代训练结束
858
+ print(wih)
859
+ print(who)
860
+ ```
861
+
862
+ [[ 0.02067233 -0.07978803 0.03108053 ... -0.03073812 0.0557655
863
+ -0.05129495]
864
+ [ 0.07106607 0.08339657 -0.09380426 ... 0.0441884 -0.03837313
865
+ -0.08557481]
866
+ [ 0.05502248 -0.09130093 0.0384007 ... -0.00538593 0.06249898
867
+ 0.08624116]
868
+ ...
869
+ [-0.00994263 -0.07816935 -0.01082394 ... 0.00301429 -0.00230436
870
+ 0.09999818]
871
+ [ 0.03612263 -0.01946694 0.0954403 ... 0.01146139 -0.00025476
872
+ -0.12006706]
873
+ [-0.05983042 -0.01998364 -0.06092712 ... -0.02392167 -0.06806361
874
+ 0.01094472]]
875
+ [[-0.0204511 -0.07749507 -0.00194097 ... 0.09540347 0.008829
876
+ -0.02140005]
877
+ [-0.01264093 0.06889351 0.04956639 ... -0.0669025 -0.01843888
878
+ 0.00722866]
879
+ [-0.05481193 0.04000967 -0.09688887 ... 0.07287872 0.11162873
880
+ -0.12241058]
881
+ ...
882
+ [-0.0972931 0.05829893 0.13900051 ... 0.04472318 0.0444388
883
+ -0.1383636 ]
884
+ [ 0.0109094 0.01127165 0.00850074 ... 0.00947806 -0.08093348
885
+ -0.17257885]
886
+ [-0.08861324 0.04998882 0.03560659 ... 0.05427103 -0.06461784
887
+ 0.01395731]]
888
+
889
+
890
+
891
+ ```python
892
+ print(wih)
893
+ # Visualize weight matrix wih
894
+ plt.imshow(wih, cmap='coolwarm', aspect='auto')
895
+ plt.xlabel('Output Node')
896
+ plt.ylabel('Hidden Node')
897
+ plt.title('Weight Matrix (Hidden to input)')
898
+ plt.colorbar()
899
+ plt.show()
900
+ ```
901
+
902
+ [[ 0.02067233 -0.07978803 0.03108053 ... -0.03073812 0.0557655
903
+ -0.05129495]
904
+ [ 0.07106607 0.08339657 -0.09380426 ... 0.0441884 -0.03837313
905
+ -0.08557481]
906
+ [ 0.05502248 -0.09130093 0.0384007 ... -0.00538593 0.06249898
907
+ 0.08624116]
908
+ ...
909
+ [-0.00994263 -0.07816935 -0.01082394 ... 0.00301429 -0.00230436
910
+ 0.09999818]
911
+ [ 0.03612263 -0.01946694 0.0954403 ... 0.01146139 -0.00025476
912
+ -0.12006706]
913
+ [-0.05983042 -0.01998364 -0.06092712 ... -0.02392167 -0.06806361
914
+ 0.01094472]]
915
+
916
+
917
+
918
+
919
+ ![png](output_33_1.png)
920
+
921
+
922
+
923
+
924
+ ```python
925
+ print(who)
926
+ # Visualize weight matrix who
927
+ plt.imshow(who, cmap='coolwarm', aspect='auto')
928
+ plt.xlabel('Output Node')
929
+ plt.ylabel('Hidden Node')
930
+ plt.title('Weight Matrix (Hidden to Output)')
931
+ plt.colorbar()
932
+ plt.show()
933
+ ```
934
+
935
+ [[-0.0204511 -0.07749507 -0.00194097 ... 0.09540347 0.008829
936
+ -0.02140005]
937
+ [-0.01264093 0.06889351 0.04956639 ... -0.0669025 -0.01843888
938
+ 0.00722866]
939
+ [-0.05481193 0.04000967 -0.09688887 ... 0.07287872 0.11162873
940
+ -0.12241058]
941
+ ...
942
+ [-0.0972931 0.05829893 0.13900051 ... 0.04472318 0.0444388
943
+ -0.1383636 ]
944
+ [ 0.0109094 0.01127165 0.00850074 ... 0.00947806 -0.08093348
945
+ -0.17257885]
946
+ [-0.08861324 0.04998882 0.03560659 ... 0.05427103 -0.06461784
947
+ 0.01395731]]
948
+
949
+
950
+
951
+
952
+ ![png](output_34_1.png)
953
+
954
+
955
+
956
+
957
+ ```python
958
+ #完整训练流程
959
+ ```
960
+
961
+
962
+ ```python
963
+ input_nodes=784 # 输入层节点数
964
+ hide_nodes=200 # 隐藏层节点数
965
+ out_nodes=10 # 输出层节点数
966
+ learningrate = 0.1 #学习率
967
+ train_errors = []
968
+ epochs=5
969
+
970
+ wih = np.random.normal(0.0, pow(hide_nodes, -0.5), (hide_nodes, input_nodes)) #矩阵大小为隐藏层节点数×输入层节点数
971
+ #np.random.normal()的意思是一个正态分布,normal这里是正态的意思
972
+ who = np.random.normal(0.0, pow(hide_nodes, -0.5), (out_nodes, hide_nodes)) #矩阵大小为输出层节点数×隐藏层节点数
973
+ activation_function = lambda x: ssp.expit(x) #结合上述所学,这里写一段原理是logistic sigmoid的激活函数
974
+
975
+ test_data_file = open("mnist_train.csv", 'r')
976
+ test_data_list = test_data_file.readlines()
977
+ test_data_file.close()
978
+
979
+ for e in range(epochs):
980
+ # go through all records in the training data set
981
+ # 遍历所有输入的数据
982
+ print('epochs start:',e)
983
+ # 计算训练集上的误差
984
+ train_error = 0.0
985
+ for record in test_data_list:
986
+ all_values = record.split(',') # split()函数将第1条数据进行拆分,以‘,’为分界点进行拆分
987
+ inputs = (np.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01 #输入层,784个输入
988
+ targets = np.zeros(out_nodes) + 0.01
989
+ targets[int(all_values[0])] = 0.99
990
+ INPUT = np.array(inputs, ndmin = 2).T # array函数是矩阵生成函数,将输入的inputs转换成二维矩阵,ndmin=2表示二维矩阵
991
+ TARGETS = np.array(targets, ndmin = 2).T # .T表示矩阵的转置,生成后的矩阵的转置矩阵送入变量targets
992
+ # 进行前向传播
993
+ # 利用导入的数据计算进入隐藏层的数据
994
+ hidden_inputs = np.dot(wih, INPUT) # dot()函数是指两个矩阵做点乘
995
+ # 利用激活函数sigmoid计算隐藏层输出的数据
996
+ hidden_outputs = activation_function(hidden_inputs)
997
+ # 利用隐藏层输出的数据计算导入输出层的数据
998
+ final_inputs = np.dot(who, hidden_outputs) # dot()函数是指两个矩阵做点乘
999
+ # 利用激活函数sigmoid计算输出层的输出结果
1000
+ final_outputs = activation_function(final_inputs)
1001
+ # 前向传播结束
1002
+ # 进行反向传播
1003
+ # 计算前向传播得到的输出结果与正确值之间的误差
1004
+ output_errors = TARGETS - final_outputs
1005
+ # 隐藏层的误差是由输出层的误差通过两个层之间的权重矩阵进行分配的,在隐藏层重新结合
1006
+ hidden_errors = np.dot(who.T, output_errors) # 隐藏层与输出层之间的权重矩阵的转置与前向传播的误差矩阵的点乘
1007
+ # 对隐藏层与输出层之间的权重矩阵进行更新迭代
1008
+ who += learningrate * np.dot((output_errors * final_outputs * (1.0 - final_outputs)),np.transpose(hidden_outputs))
1009
+ # 对输入层与隐藏层之间的权重矩阵进行更新迭代
1010
+ wih += learningrate * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(INPUT))
1011
+ train_error += np.sum((output_errors) ** 2)
1012
+ train_error /= len(test_data_list)
1013
+ train_errors.append(train_error)
1014
+
1015
+ # 画出误差曲线
1016
+ plt.plot(train_errors, label='training error')
1017
+ plt.legend()
1018
+ plt.show()
1019
+ ```
1020
+
1021
+ epochs start: 0
1022
+ epochs start: 1
1023
+ epochs start: 2
1024
+ epochs start: 3
1025
+ epochs start: 4
1026
+
1027
+
1028
+
1029
+
1030
+ ![png](output_36_1.png)
1031
+
1032
+
1033
+
1034
+
1035
+ ```python
1036
+ #最终结果,这两个变量就是最终的权重(weights)
1037
+ print(who)
1038
+ print(wih)
1039
+ final_who=who
1040
+ final_wih=wih
1041
+ ```
1042
+
1043
+ [[-1.16611326 -0.4525141 -0.06610833 ... -0.45357449 -0.48939251
1044
+ 0.64537313]
1045
+ [-0.23350166 -0.07640343 -0.33892076 ... -0.42012762 -0.09425477
1046
+ -0.35624211]
1047
+ [ 0.02538154 -0.36034837 -0.31796842 ... -0.03179198 0.24630403
1048
+ 0.53641215]
1049
+ ...
1050
+ [-0.62273744 1.44743377 0.37902492 ... -1.22510993 0.85708252
1051
+ -0.0379783 ]
1052
+ [-0.30649461 -0.45335212 -0.75158325 ... 0.27636151 -0.47017666
1053
+ -0.43715161]
1054
+ [ 0.01993143 -1.11644346 1.10811109 ... 0.39435807 -0.77164373
1055
+ -0.37836149]]
1056
+ [[ 0.01027389 -0.06948278 -0.13336783 ... 0.0431249 0.0116984
1057
+ 0.01118535]
1058
+ [ 0.04093141 0.13349408 0.0447183 ... -0.02876729 -0.08677845
1059
+ -0.05826928]
1060
+ [-0.11370514 -0.04104104 0.05438874 ... -0.00457712 -0.01669163
1061
+ -0.02552346]
1062
+ ...
1063
+ [-0.00480138 0.04369124 -0.07553194 ... 0.09218518 0.02003152
1064
+ 0.0808828 ]
1065
+ [-0.00826098 0.07729079 -0.12576362 ... 0.03445958 0.02413203
1066
+ -0.08935369]
1067
+ [-0.03758297 -0.06222281 0.02554687 ... 0.13169544 0.01547494
1068
+ -0.07650541]]
1069
+
1070
+
1071
+
1072
+ ```python
1073
+ #保存权重
1074
+ np.save("weights", final_who)
1075
+ np.save("weights02",final_wih)
1076
+ ```
1077
+
1078
+
1079
+ ```python
1080
+ #测试
1081
+ ```
1082
+
1083
+
1084
+ ```python
1085
+ #加载权重文件(weights)
1086
+ final_who=np.load("weights.npy")
1087
+ final_wih=np.load("weights02.npy")
1088
+ ```
1089
+
1090
+
1091
+ ```python
1092
+ # Visualize weight matrix wih
1093
+ plt.imshow(final_wih, cmap='coolwarm', aspect='auto')
1094
+ plt.xlabel('Output Node')
1095
+ plt.ylabel('Hidden Node')
1096
+ plt.title('Weight Matrix (Hidden to input)')
1097
+ plt.colorbar()
1098
+ plt.show()
1099
+ ```
1100
+
1101
+
1102
+
1103
+ ![png](output_41_0.png)
1104
+
1105
+
1106
+
1107
+
1108
+ ```python
1109
+ # Visualize weight matrix who
1110
+ plt.imshow(final_who, cmap='coolwarm', aspect='auto')
1111
+ plt.xlabel('Output Node')
1112
+ plt.ylabel('Hidden Node')
1113
+ plt.title('Weight Matrix (Hidden to output)')
1114
+ plt.colorbar()
1115
+ plt.show()
1116
+ ```
1117
+
1118
+
1119
+
1120
+ ![png](output_42_0.png)
1121
+
1122
+
1123
+
1124
+
1125
+ ```python
1126
+ test_data_file = open("mnist_test.csv", 'r')
1127
+ test_data_list = test_data_file.readlines()
1128
+ test_data_file.close()
1129
+ ```
1130
+
1131
+
1132
+ ```python
1133
+ data_serial_num=455
1134
+ all_values = test_data_list[data_serial_num].split(',') # split()函数将第1条数据进行拆分,以‘,’为分界点进行拆分
1135
+ inputs = (np.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01
1136
+ #print(inputs)
1137
+ image_array = np.asfarray(all_values[1:]).reshape((28,28)) # asfarray()函数将all_values中的后784个数字进行重新排列
1138
+ # reshape()函数可以对数组进行整型,使其成为28×28的二维数组,asfarry()函数可以使其成为矩阵。
1139
+ plt.imshow(image_array, interpolation = 'nearest') # imshow()函数可以将28×28的矩阵中的数值当做像素值,使其形成图片
1140
+ ```
1141
+
1142
+
1143
+
1144
+
1145
+ <matplotlib.image.AxesImage at 0x7fa3d809b2e0>
1146
+
1147
+
1148
+
1149
+
1150
+
1151
+ ![png](output_44_1.png)
1152
+
1153
+
1154
+
1155
+
1156
+ ```python
1157
+ test_inputs = np.array(inputs, ndmin = 2).T
1158
+ # 以下程序为计算输出结果的程序,与上面前向传播算法一致
1159
+ hidden_inputs = np.dot(final_wih, test_inputs)
1160
+ hidden_outputs = activation_function(hidden_inputs)
1161
+ final_inputs = np.dot(final_who, hidden_outputs)
1162
+ final_outputs = activation_function(final_inputs)
1163
+ print(final_outputs)
1164
+ ```
1165
+
1166
+ [[0.01072488]
1167
+ [0.99333831]
1168
+ [0.00781424]
1169
+ [0.00584866]
1170
+ [0.02362064]
1171
+ [0.01216366]
1172
+ [0.00683059]
1173
+ [0.00921785]
1174
+ [0.00169813]
1175
+ [0.00730339]]
1176
+
1177
+
1178
+
1179
+ ```python
1180
+ # Visualize hidden layer activations
1181
+ #hidden_inputs = hidden_inputs.reshape((20, 10))
1182
+ plt.imshow(hidden_inputs.T, cmap='hot', aspect='auto')
1183
+ plt.xlabel('Hidden Node')
1184
+ plt.ylabel('Sample')
1185
+ plt.title('Hidden Layer Activations')
1186
+ plt.colorbar()
1187
+ plt.show()
1188
+ ```
1189
+
1190
+
1191
+
1192
+ ![png](output_46_0.png)
1193
+
1194
+
1195
+
1196
+
1197
+ ```python
1198
+ # Visualize hidden layer activations
1199
+ plt.imshow(hidden_outputs.T, cmap='hot', aspect='auto')
1200
+ plt.xlabel('Hidden Node')
1201
+ plt.ylabel('Sample')
1202
+ plt.title('Hidden Layer Activations')
1203
+ plt.colorbar()
1204
+ plt.show()
1205
+ ```
1206
+
1207
+
1208
+
1209
+ ![png](output_47_0.png)
1210
+
1211
+
1212
+
1213
+
1214
+ ```python
1215
+ #可视化中间输出
1216
+ print(final_inputs.T)
1217
+ middle_layer_fig = np.asfarray((final_inputs-0.01)/0.99*255.0 )
1218
+ middle_layer_fig = np.asfarray(middle_layer_fig).reshape((1,10))
1219
+ plt.imshow(middle_layer_fig, interpolation = 'nearest')
1220
+ ```
1221
+
1222
+ [[-4.52440665 5.00469803 -4.84396237 -5.13567656 -3.72173031 -4.39706459
1223
+ -4.97949043 -4.67735268 -6.37652596 -4.91208681]]
1224
+
1225
+
1226
+
1227
+
1228
+
1229
+ <matplotlib.image.AxesImage at 0x7fa3cbe083a0>
1230
+
1231
+
1232
+
1233
+
1234
+
1235
+ ![png](output_48_2.png)
1236
+
1237
+
1238
+
1239
+
1240
+ ```python
1241
+ # Or visualize final outputs as a heatmap
1242
+ plt.imshow(final_inputs, cmap='hot', aspect='auto')
1243
+ plt.xlabel('Input Node')
1244
+ plt.ylabel('Sample')
1245
+ plt.title('Final Inputs')
1246
+ plt.colorbar()
1247
+ plt.show()
1248
+ ```
1249
+
1250
+
1251
+
1252
+ ![png](output_49_0.png)
1253
+
1254
+
1255
+
1256
+
1257
+ ```python
1258
+ # Visualize final layer inputs
1259
+ plt.bar(range(out_nodes), final_inputs.flatten())
1260
+ plt.xlabel('Output Node')
1261
+ plt.ylabel('Input Value')
1262
+ plt.title('Final Layer Inputs')
1263
+ plt.show()
1264
+ ```
1265
+
1266
+
1267
+
1268
+ ![png](output_50_0.png)
1269
+
1270
+
1271
+
1272
+
1273
+ ```python
1274
+ #可视化中间输出
1275
+ print(final_outputs.T)
1276
+ middle_layer_fig = np.asfarray((final_outputs-0.01)/0.99*255.0 )
1277
+ middle_layer_fig = np.asfarray(middle_layer_fig).reshape((1,10))
1278
+ plt.imshow(middle_layer_fig, interpolation = 'nearest')
1279
+ ```
1280
+
1281
+ [[0.01072488 0.99333831 0.00781424 0.00584866 0.02362064 0.01216366
1282
+ 0.00683059 0.00921785 0.00169813 0.00730339]]
1283
+
1284
+
1285
+
1286
+
1287
+
1288
+ <matplotlib.image.AxesImage at 0x7fa3cbcba550>
1289
+
1290
+
1291
+
1292
+
1293
+
1294
+ ![png](output_51_2.png)
1295
+
1296
+
1297
+
1298
+
1299
+ ```python
1300
+ # Or visualize final outputs as a heatmap
1301
+ plt.imshow(final_outputs, cmap='hot', aspect='auto')
1302
+ plt.xlabel('Output Node')
1303
+ plt.ylabel('Sample')
1304
+ plt.title('Final Outputs')
1305
+ plt.colorbar()
1306
+ plt.show()
1307
+ ```
1308
+
1309
+
1310
+
1311
+ ![png](output_52_0.png)
1312
+
1313
+
1314
+
1315
+
1316
+ ```python
1317
+ # Visualize final layer outputs (sigmoid)
1318
+ plt.bar(range(out_nodes), final_outputs.flatten())
1319
+ plt.xlabel('Output Node')
1320
+ plt.ylabel('Input Value')
1321
+ plt.title('Final Layer Outputs')
1322
+ plt.show()
1323
+ ```
1324
+
1325
+
1326
+
1327
+ ![png](output_53_0.png)
1328
+
1329
+
1330
+
1331
+
1332
+ ```python
1333
+ lebal = np.argmax(final_outputs)
1334
+ print(lebal)
1335
+ ```
1336
+
1337
+ 1
1338
+
1339
+
1340
+
1341
+ ```python
1342
+ #模型效果和性能测试
1343
+ ```
1344
+
1345
+
1346
+ ```python
1347
+ # load the mnist test data CSV file into a list
1348
+ # 导入测试集数据
1349
+ test_data_file = open("mnist_test.csv", 'r')
1350
+ test_data_list = test_data_file.readlines()
1351
+ test_data_file.close()
1352
+ # test the neural network
1353
+ # 用query函数对测试集进行检测
1354
+ # go through all the records in the test data set for record in the test_data_list:
1355
+ scorecard = 0 # 得分卡,检测对一个加一分
1356
+ # 计算测试集上的误差
1357
+
1358
+ for record in test_data_list:
1359
+ # split the record by the ',' comas
1360
+ # 将所有测试数据通过逗号分隔开
1361
+ all_values = record.split(',')
1362
+ # correct answer is first value
1363
+ # 正确值为每一条测试数据的第一个数值
1364
+ correct_lebal = int(all_values[0])
1365
+ #print("correct lebal", correct_lebal) # 将正确的数值在屏幕上打印出来
1366
+ # scale and shift the inputs
1367
+ # 对输入数据进行处理,取后784个数据除以255,再乘以0.99,最后加上0。01,是所有的数据都在0.01到1.00之间
1368
+ inputs = (np.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01 #输入层,784个输入
1369
+
1370
+ # query the network
1371
+ # 用query函数对测试集进行检测
1372
+
1373
+ test_inputs = np.array(inputs, ndmin = 2).T
1374
+ # 以下程序为计算输出结果的程序,与上面前向传播算法一致
1375
+ hidden_inputs = np.dot(final_wih, test_inputs)
1376
+ hidden_outputs = activation_function(hidden_inputs)
1377
+ final_inputs = np.dot(final_who, hidden_outputs)
1378
+ final_outputs = activation_function(final_inputs)
1379
+
1380
+ # the index of the highest value corresponds to out label
1381
+ # 得到的数字就是输出结果的最大的数值所对应的标签
1382
+ lebal = np.argmax(final_outputs) # argmax()函数用于找出数值最大的值所对应的标签
1383
+ #print("Output is ", lebal) # 在屏幕上打出最终输出的结果
1384
+ # output image of every digit
1385
+ # 输出每一个数字的图片
1386
+ #image_correct = np.asfarray(all_values[1:]).reshape((28, 28))
1387
+ #plt.imshow(image_correct, cmap = 'Greys', interpolation = 'None')
1388
+ #plt.show()
1389
+ # append correct or incorrect to list
1390
+ if (lebal == correct_lebal):
1391
+ # network's answer matchs correct answer, add 1 to scorecard
1392
+ scorecard += 1
1393
+ else:
1394
+ # network's answer doesn't match correct answer, add 0 to scorecard
1395
+ scorecard += 0
1396
+ pass
1397
+ pass
1398
+
1399
+ # calculate the performance score, the fraction
1400
+ # 计算准确率 得分卡最后的数值/10000(测试集总个数)
1401
+ print("performance = ", scorecard / 10000)
1402
+
1403
+ ```
1404
+
1405
+ performance = 0.9722
1406
+
1407
+
1408
+
1409
+ ```python
1410
+
1411
+ ```