20. Neural Networks and the Simplest XOR Problem#
This was adapted from the PyTorch Tutorials.
Simple supervised machine learning.
http://pytorch.org/tutorials/beginner/pytorch_with_examples.html
20.1. Neural Networks#
Neural networks are the foundation of deep learning.
In the mathematical theory of artificial neural networks, the universal approximation theorem states[1] that a feed-forward network with a single hidden layer containing a finite number of neurons (i.e., a multilayer perceptron) can approximate continuous functions on compact subsets of R^n, under mild assumptions on the activation function.
A simple task that neural networks can handle but simple linear models cannot is the XOR problem.
In the XOR problem, the output is 1 if exactly one of the two inputs is 1, and 0 otherwise.
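To make the claim about linear models concrete, here is a minimal sketch (not part of the original tutorial) that fits an ordinary least-squares model with an intercept to the full two-input XOR truth table. The names `X_xor`, `y_xor`, and `linear_pred` are illustrative assumptions. The best linear fit ends up predicting the same value for every row, so it cannot separate the two classes.

```python
import numpy as np

# Full XOR truth table: two binary inputs and their XOR.
X_xor = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y_xor = np.array([0, 1, 1, 0], dtype=float)

# Append a column of ones so the linear model can learn an intercept.
A = np.hstack([X_xor, np.ones((4, 1))])

# Least-squares fit of y_xor ~ A @ coef.
coef, *_ = np.linalg.lstsq(A, y_xor, rcond=None)
linear_pred = A @ coef

print(linear_pred)  # every prediction is (numerically) 0.5
```

A hidden layer with a nonlinear activation is what lets the network escape this limitation.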
20.1.1. Generate Fake Data#
D_in is the number of dimensions of an input variable.
D_out is the number of dimensions of an output variable.
Here we are learning some special “fake” data that represents the XOR problem.
Here, the dependent variable (dv) is 1 if either the first or the second input variable is 1, but not both.
# -*- coding: utf-8 -*-
import numpy as np
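# These are our independent (x) and dependent (y) variables.
# Note: the fourth row is [0,0,0], so the (1,1) case of the XOR truth table
# is not included; the outputs shown below were produced with this data.
x = np.array([ [0,0,0],[1,0,0],[0,1,0],[0,0,0] ])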
y = np.array([[0,1,1,0]]).T
print("Input data:\n",x,"\n Output data:\n",y)
Input data:
[[0 0 0]
[1 0 0]
[0 1 0]
[0 0 0]]
Output data:
[[0]
[1]
[1]
[0]]
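For comparison with the array above, the following is a small sketch of how the complete two-input XOR truth table could be generated instead of typed by hand. The names `inputs` and `targets` are hypothetical and do not appear in the original tutorial, which uses three input columns and the rows printed above.

```python
import numpy as np

# All four combinations of two binary inputs.
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# XOR is 1 when exactly one of the two inputs is 1.
targets = np.logical_xor(inputs[:, 0], inputs[:, 1]).astype(int).reshape(-1, 1)

print(targets.ravel())  # [0 1 1 0]
```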
20.1.2. A Simple Neural Network#
Here we are going to build a neural network.
The first layer (D_in) has to match the length of each input row.
H is the size of the hidden layer.
D_out is 1, since the network outputs a single value that we read as the predicted score for the target being 1.
np.random.seed(seed=83832)
#D_in is the number of input variables.
#H is the hidden dimension.
#D_out is the number of dimensions for the output.
D_in, H, D_out = 3, 2, 1
# Randomly initialize the weights of our network (one hidden layer with H units).
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)
# Note: this bias is initialized but is not used in the training loop below.
bias = np.random.randn(H, 1)
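As a rough sanity check on these dimensions, the sketch below runs a single forward pass and prints the shape of each intermediate array. It reuses the data and layer sizes from the text; only the shapes are asserted here, since the actual values depend on the random draw.

```python
import numpy as np

np.random.seed(83832)

# Same data and layer sizes as in the text.
x = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 0]])
D_in, H, D_out = 3, 2, 1
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

h = x.dot(w1)                # (4, 3) @ (3, 2) -> (4, 2)
h_relu = np.maximum(h, 0)    # ReLU keeps the shape: (4, 2)
y_pred = h_relu.dot(w2)      # (4, 2) @ (2, 1) -> (4, 1)

print(h.shape, h_relu.shape, y_pred.shape)  # (4, 2) (4, 2) (4, 1)
```

Each of the four input rows therefore yields one prediction, which matches the (4, 1) shape of y.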
20.1.4. Update the Weights using Gradient Descent#
Calculate the predicted value
Calculate the loss function
Compute the gradients of the loss function with respect to w1 and w2
Update the weights using the learning rate
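The four steps above can be written out as equations. This is a sketch in notation chosen for this note, not taken from the original tutorial: W_1 and W_2 stand for w1 and w2, \hat{y} for y_pred, \eta for learning_rate, and \odot for element-wise multiplication. The indicator uses h \ge 0 to mirror the code, which zeroes the gradient only where h < 0.

```latex
\begin{aligned}
h &= x W_1, \qquad h_{\mathrm{relu}} = \max(h, 0), \qquad \hat{y} = h_{\mathrm{relu}} W_2, \\
L &= \sum_i (\hat{y}_i - y_i)^2, \\
\frac{\partial L}{\partial \hat{y}} &= 2(\hat{y} - y), \\
\frac{\partial L}{\partial W_2} &= h_{\mathrm{relu}}^{\top} \frac{\partial L}{\partial \hat{y}}, \\
\frac{\partial L}{\partial h_{\mathrm{relu}}} &= \frac{\partial L}{\partial \hat{y}} W_2^{\top}, \\
\frac{\partial L}{\partial h} &= \frac{\partial L}{\partial h_{\mathrm{relu}}} \odot \mathbf{1}[h \ge 0], \\
\frac{\partial L}{\partial W_1} &= x^{\top} \frac{\partial L}{\partial h}, \\
W_k &\leftarrow W_k - \eta \frac{\partial L}{\partial W_k}, \quad k \in \{1, 2\}.
\end{aligned}
```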
learning_rate = .01
for t in range(500):
    # Forward pass: compute predicted y
    h = x.dot(w1)
    # ReLU is the activation function.
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)
    # Compute and print loss (sum of squared errors)
    loss = np.square(y_pred - y).sum()
    print(t, loss)
    # Backprop to compute gradients of the loss with respect to w1 and w2
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)
    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2
0 10.65792615907139
1 9.10203339892777
2 7.928225580610054
3 7.016030709608875
4 6.289798199184453
5 5.699847385692147
6 5.2123305302347624
7 4.803466247932402
8 4.456102755004962
9 4.1575876890269665
10 3.898402733982808
11 3.671262676836925
12 3.4705056296083194
13 3.291670966818706
14 3.1312013137273507
15 2.9862283397788603
16 2.854416299096229
17 2.733846078586037
18 2.622928124188624
19 2.5203362600714687
20 2.4249568284296723
21 2.335849203166264
22 2.2522148435722413
23 2.173372827242625
24 2.0987403459205147
25 2.0278170362586616
26 1.9601722976944669
27 1.8954349540796849
28 1.8332847664299674
29 1.773445416375481
30 1.7156786642283006
31 1.6597794495384952
32 1.6055717509418743
33 1.5529050598636533
34 1.5016513520352168
35 1.4517024638575724
36 1.402967798918812
37 1.3553723045677533
38 1.3088546702007282
39 1.2633657084618402
40 1.2188668883615361
41 1.175328995740111
42 1.1327309018102432
43 1.0910584249086992
44 1.0503032742251954
45 1.0104620672725408
46 0.9715354153057676
47 0.9335270728583867
48 0.8964431490967778
49 0.8602913798451839
50 0.8250804599443017
51 0.7908194361131566
52 0.7575171607226662
53 0.725181806895671
54 0.6938204451583221
55 0.6634386815182467
56 0.6340403563729178
57 0.6056273030939351
58 0.5781991645253675
59 0.5517532650109411
60 0.5262845349569829
61 0.5017854843733736
62 0.4782462213367543
63 0.45565451090747644
64 0.4339958697177296
65 0.4132536912411343
66 0.39340939665695995
67 0.3744426062333802
68 0.3563313262679901
69 0.339052146830798
70 0.3225804458429691
71 0.3068905953796952
72 0.2919561664924972
73 0.27775012928952875
74 0.26424504547683114
75 0.25141325103470336
76 0.23922702716849936
77 0.2276587581210479
78 0.2166810748552383
79 0.20626698400286633
80 0.1963899818243421
81 0.1870241532299702
82 0.17814425617563434
83 0.16972579196376805
84 0.1617450621557261
85 0.1541792129363326
86 0.147006267868564
87 0.140205150039638
88 0.13375569463319448
89 0.1276386529698942
90 0.1218356890447229
91 0.11632936955756723
92 0.11110314838793904
93 0.1061413464085187
94 0.10142912746853275
95 0.09695247130958222
96 0.09269814410571503
97 0.08865366724822979
98 0.08480728492549391
99 0.0811479309802164
100 0.07766519546208477
101 0.07434929123315356
102 0.0711910209273244
103 0.06818174451394585
104 0.06531334766910445
105 0.06257821111654908
106 0.05996918106326466
107 0.05747954082229276
108 0.05510298368722062
109 0.052833587098539955
110 0.050665788121483174
111 0.04859436023766026
112 0.046614391438508374
113 0.044721263596905425
114 0.042910633083986426
115 0.04117841259093982
116 0.0395207541100794
117 0.03793403302553889
118 0.03641483326129067
119 0.03495993343264216
120 0.03356629394672944
121 0.0322310449976488
122 0.03095147540259279
123 0.029725022226571782
124 0.028549261144892105
125 0.027421897494436413
126 0.0263407579668736
127 0.025303782899147254
128 0.02430901911889927
129 0.023354613304830955
130 0.0224388058243556
131 0.021559925013216318
132 0.02071638186401746
133 0.019906665092819995
134 0.019129336555072554
135 0.01838302698417915
136 0.017666432027935646
137 0.016978308559893886
138 0.016317471244437356
139 0.015682789335971285
140 0.015073183694146064
141 0.014487623998449207
142 0.01392512614681838
143 0.013384749824153785
144 0.012865596227742457
145 0.01236680593765744
146 0.011887556921165793
147 0.01142706266107241
148 0.010984570398752825
149 0.010559359483383497
150 0.010150739819576832
151 0.009758050406265562
152 0.009380657960268358
153 0.009017955618505363
154 0.00866936171332396
155 0.008334318615846222
156 0.008012291642660496
157 0.007702768021557011
158 0.0074052559123513875
159 0.007119283479154988
160 0.00684439801073854
161 0.006580165085898312
162 0.006326167780975099
163 0.0060820059168944065
164 0.005847295343298466
165 0.005621667257523059
166 0.00540476755634021
167 0.00519625621854179
168 0.004995806716578387
169 0.004803105455597478
170 0.004617851238342064
171 0.004439754754478814
172 0.004268538093024131
173 0.004103934276626908
174 0.003945686816550616
175 0.0037935492872733776
176 0.0036472849196956558
177 0.003506666212010064
178 0.0033714745573471235
179 0.0032414998873664088
180 0.0031165403310131914
181 0.002996401887707673
182 0.0028808981142773157
183 0.0027698498249834596
184 0.002663084804030042
185 0.002560437529977551
186 0.0024617489115170118
187 0.0023668660340890185
188 0.0022756419168604115
189 0.002187935279597463
190 0.002103610318998506
191 0.0020225364940712155
192 0.0019445883201617943
193 0.0018696451712621129
194 0.0017975910902400137
195 0.0017283146066554844
196 0.0016617085618411765
197 0.0015976699409418664
198 0.0015360997116213772
199 0.001476902669159437
200 0.0014199872876736073
201 0.0013652655772137235
202 0.0013126529464877815
203 0.0012620680709887451
204 0.0012134327663025218
205 0.0011666718663866913
206 0.0011217131066188356
207 0.0010784870114223695
208 0.0010369267862857035
209 0.000996968213998663
210 0.0009585495549376858
211 0.0009216114512381271
212 0.0008860968346992506
213 0.0008519508382735161
214 0.0008191207109983951
215 0.0007875557362343561
216 0.000757207153078624
217 0.0007280280808296751
218 0.0006999734463822273
219 0.000672999914437815
220 0.0006470658204202973
221 0.000622131105990457
222 0.0005981572570579552
223 0.0005751072441928266
224 0.0005529454653431726
225 0.0005316376907686722
226 0.0005111510101038712
227 0.0004914537814680439
228 0.0004725155825422385
229 0.0004543071635367196
230 0.0004368004019756303
231 0.00041996825922807537
232 0.0004037847387179752
233 0.00038822484574746864
234 0.000373264548871404
235 0.0003588807427627314
236 0.0003450512125110722
237 0.00033175459929894827
238 0.00031897036740236116
239 0.00030667877246449906
240 0.00029486083099329336
241 0.00028349829103548963
242 0.0002725736039817843
243 0.0002620698974593076
244 0.0002519709492694341
245 0.00024226116233050105
246 0.00023292554058671085
247 0.00022394966584582068
248 0.00021531967550979456
249 0.00020702224116398246
250 0.00019904454799161983
251 0.0001913742749818817
252 0.0001839995759008063
253 0.00017690906099565552
254 0.00017009177940448836
255 0.00016353720224369502
256 0.00015723520634729167
257 0.00015117605863298604
258 0.00014535040107069527
259 0.00013974923623037293
260 0.00013436391338677672
261 0.0001291861151597235
262 0.00012420784466916755
263 0.00011942141318526729
264 0.00011481942825437294
265 0.00011039478228253808
266 0.00010614064155901145
267 0.0001020504357026523
268 9.811784751502961e-05
269 9.433680322452288e-05
270 9.07014631063634e-05
271 8.720621246405432e-05
272 8.384565295836376e-05
273 8.06145942704182e-05
274 7.750804608602226e-05
275 7.45212103888556e-05
276 7.16494740506235e-05
277 6.888840170673525e-05
278 6.623372890648624e-05
279 6.36813555272088e-05
280 6.122733944212446e-05
281 5.8867890432272306e-05
282 5.659936433295092e-05
283 5.4418257405766915e-05
284 5.232120092748597e-05
285 5.030495598744348e-05
286 4.8366408485379654e-05
287 4.650256432203307e-05
288 4.4710544775062853e-05
289 4.2987582053132594e-05
290 4.1331015021312185e-05
291 3.9738285091188996e-05
292 3.8206932269354095e-05
293 3.673459135812327e-05
294 3.531898830268317e-05
295 3.3957936678981286e-05
296 3.2649334316941296e-05
297 3.139116005380407e-05
298 3.0181470612568156e-05
299 2.901839760070956e-05
300 2.790014462456529e-05
301 2.6824984514883835e-05
302 2.579125665930827e-05
303 2.479736443763488e-05
304 2.3841772755898297e-05
305 2.2923005675484053e-05
306 2.2039644133597908e-05
307 2.119032375156838e-05
308 2.0373732727611678e-05
309 1.958860981078675e-05
310 1.8833742353025982e-05
311 1.8107964436237772e-05
312 1.7410155071560296e-05
313 1.673923646802122e-05
314 1.6094172367903756e-05
315 1.547396644626215e-05
316 1.4877660772109866e-05
317 1.430433432890405e-05
318 1.3753101592043518e-05
319 1.3223111161177204e-05
320 1.271354444522436e-05
321 1.2223614398052364e-05
322 1.1752564302897218e-05
323 1.1299666603598998e-05
324 1.0864221780898391e-05
325 1.0445557272020895e-05
326 1.0043026431896353e-05
327 9.656007534414035e-06
328 9.283902812157043e-06
329 8.926137533144259e-06
330 8.582159113149512e-06
331 8.251436262223122e-06
332 7.933458164113636e-06
333 7.627733687293064e-06
334 7.33379062640477e-06
335 7.051174972921215e-06
336 6.779450213931473e-06
337 6.518196657925287e-06
338 6.267010786575047e-06
339 6.025504631492694e-06
340 5.793305174999138e-06
341 5.570053773989272e-06
342 5.355405605998847e-06
343 5.149029136614468e-06
344 4.9506056074123485e-06
345 4.759828543619497e-06
346 4.576403280761419e-06
347 4.4000465095373554e-06
348 4.230485838242033e-06
349 4.0674593720484435e-06
350 3.91071530849992e-06
351 3.760011548592869e-06
352 3.615115322845847e-06
353 3.4758028317824277e-06
354 3.3418589002573495e-06
355 3.2130766451173086e-06
356 3.0892571556569434e-06
357 2.9702091863957054e-06
358 2.8557488616903765e-06
359 2.7456993917342825e-06
360 2.6398907994953916e-06
361 2.538159658179872e-06
362 2.4403488388159694e-06
363 2.3463072675564664e-06
364 2.2558896923420854e-06
365 2.1689564585528806e-06
366 2.0853732933059644e-06
367 2.0050110980633565e-06
368 1.9277457492352228e-06
369 1.8534579064655777e-06
370 1.7820328283025376e-06
371 1.7133601949808364e-06
372 1.6473339380218593e-06
373 1.5838520764096164e-06
374 1.5228165590753771e-06
375 1.464133113451991e-06
376 1.40771109986799e-06
377 1.3534633715480287e-06
378 1.3013061400118565e-06
379 1.2511588456580528e-06
380 1.202944033337219e-06
381 1.156587232716991e-06
382 1.1120168432603495e-06
383 1.0691640236341087e-06
384 1.0279625853816399e-06
385 9.883488906893815e-07
386 9.502617540972218e-07
387 9.136423479925375e-07
388 8.784341117501692e-07
389 8.445826643692572e-07
390 8.120357204815001e-07
391 7.807430095922277e-07
392 7.506561984347445e-07
393 7.217288163193481e-07
394 6.939161833544208e-07
395 6.671753414368786e-07
396 6.41464987902722e-07
397 6.167454117308101e-07
398 5.929784322084336e-07
399 5.701273399586136e-07
400 5.481568402408807e-07
401 5.270329984367319e-07
402 5.06723187636041e-07
403 4.871960382424293e-07
404 4.684213895220894e-07
405 4.5037024301790184e-07
406 4.33014717761703e-07
407 4.16328007208233e-07
408 4.002843378327513e-07
409 3.848589293226123e-07
410 3.7002795630127346e-07
411 3.5576851152987916e-07
412 3.4205857052566243e-07
413 3.288769575438337e-07
414 3.162033128703886e-07
415 3.040180613761368e-07
416 2.9230238228100927e-07
417 2.8103818008466216e-07
418 2.7020805661710233e-07
419 2.5979528416658235e-07
420 2.4978377964218444e-07
421 2.4015807973288856e-07
422 2.3090331702399554e-07
423 2.2200519703291648e-07
424 2.1344997613110857e-07
425 2.0522444031625843e-07
426 1.9731588480219806e-07
427 1.897120943962158e-07
428 1.8240132463183187e-07
429 1.7537228362828715e-07
430 1.6861411465060232e-07
431 1.6211637934017274e-07
432 1.5586904159212865e-07
433 1.498624520541246e-07
434 1.4408733322270431e-07
435 1.3853476511318796e-07
436 1.331961714828652e-07
437 1.2806330658368513e-07
438 1.2312824242717036e-07
439 1.1838335653864165e-07
440 1.1382132018416605e-07
441 1.0943508705079577e-07
442 1.0521788236287128e-07
443 1.0116319241884456e-07
444 9.726475452980349e-08
445 9.351654734641305e-08
446 8.991278155889296e-08
447 8.64478909551478e-08
448 8.311652382328898e-08
449 7.991353468584456e-08
450 7.683397635167653e-08
451 7.387309227562576e-08
452 7.102630921117364e-08
453 6.828923014717936e-08
454 6.565762751588525e-08
455 6.312743666418616e-08
456 6.06947495745769e-08
457 5.835580882951291e-08
458 5.6107001807907785e-08
459 5.394485510481166e-08
460 5.1866029167092425e-08
461 4.986731313500184e-08
462 4.7945619882884734e-08
463 4.6097981250896086e-08
464 4.432154346061449e-08
465 4.2613562707360456e-08
466 4.097140092171029e-08
467 3.939252169562097e-08
468 3.7874486364041686e-08
469 3.6414950238919145e-08
470 3.501165898714879e-08
471 3.366244514904944e-08
472 3.236522479040909e-08
473 3.1117994283806455e-08
474 2.991882721375856e-08
475 2.8765871401598003e-08
476 2.7657346044208938e-08
477 2.6591538963808025e-08
478 2.5566803963398962e-08
479 2.4581558284015407e-08
480 2.3634280160019428e-08
481 2.2723506468796883e-08
482 2.1847830470759422e-08
483 2.100589963668148e-08
484 2.0196413558541738e-08
485 1.9418121941025346e-08
486 1.8669822670309218e-08
487 1.7950359957501043e-08
488 1.7258622553235455e-08
489 1.65935420314299e-08
490 1.5954091138984373e-08
491 1.5339282209192226e-08
492 1.4748165636183706e-08
493 1.4179828408244494e-08
494 1.3633392697563574e-08
495 1.3108014504384369e-08
496 1.2602882353487569e-08
497 1.2117216040820382e-08
498 1.1650265428291116e-08
499 1.1201309285195922e-08
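One way to gain confidence in hand-written backpropagation is to compare an analytic gradient against a numerical one. The sketch below does this for w2 only, using central differences; the helper `loss_fn`, the step size `eps`, and the seed are assumptions made for this illustration, and the weights are a fresh random draw rather than the trained ones.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, chosen for this sketch

x = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=float)
y = np.array([[0, 1, 1, 0]], dtype=float).T
w1 = rng.standard_normal((3, 2))
w2 = rng.standard_normal((2, 1))

def loss_fn(w1, w2):
    """Sum of squared errors for the two-layer ReLU network."""
    y_pred = np.maximum(x.dot(w1), 0).dot(w2)
    return np.square(y_pred - y).sum()

# Analytic gradient of the loss with respect to w2 (same formula as the loop).
h_relu = np.maximum(x.dot(w1), 0)
grad_y_pred = 2.0 * (h_relu.dot(w2) - y)
grad_w2 = h_relu.T.dot(grad_y_pred)

# Numerical gradient by central differences, one entry of w2 at a time.
eps = 1e-6
num_grad_w2 = np.zeros_like(w2)
for i in range(w2.shape[0]):
    step = np.zeros_like(w2)
    step[i, 0] = eps
    num_grad_w2[i, 0] = (loss_fn(w1, w2 + step) - loss_fn(w1, w2 - step)) / (2 * eps)

gradients_match = np.allclose(grad_w2, num_grad_w2, atol=1e-5)
print(gradients_match)
```

The check is restricted to w2 because the loss is a smooth quadratic in w2, whereas ReLU introduces kinks in the dependence on w1 that can make finite differences less reliable.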
Fully connected
20.1.5. Verify the Predictions#
Obtain predicted values from our model and compare them to the original targets.
pred = np.maximum(x.dot(w1),0).dot(w2)
print(pred, "\n", y)
[[0. ]
[0.99992661]
[1.00007337]
[0. ]]
[[0]
[1]
[1]
[0]]
y
array([[0],
[1],
[1],
[0]])
#We can see that the weights have been updated.
w1
array([[-0.20401151, 1.01377406],
[-0.10186284, 1.01392285],
[ 1.07856887, 0.01873049]])
w2
array([[0.49346731],
[0.98634069]])
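As a further check, the weights printed above can be plugged back into the forward pass. The sketch below copies those printed values as constants and rounds the predictions to hard 0/1 labels; the name `pred_label` is an assumption introduced for this example.

```python
import numpy as np

x = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 0]])
y = np.array([[0, 1, 1, 0]]).T

# Weights copied from the printed output above.
w1 = np.array([[-0.20401151, 1.01377406],
               [-0.10186284, 1.01392285],
               [ 1.07856887, 0.01873049]])
w2 = np.array([[0.49346731],
               [0.98634069]])

# Forward pass: linear layer, ReLU, linear layer.
pred = np.maximum(x.dot(w1), 0).dot(w2)

# Round to the nearest integer to get hard 0/1 predictions.
pred_label = np.rint(pred).astype(int)
print((pred_label == y).all())  # True
```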
# ReLU just replaces the negative numbers with zero.
h_relu
array([[0. , 0. ],
[0. , 1.01377258],
[0. , 1.01392433],
[0. , 0. ]])
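The effect of ReLU is easiest to see on a small vector. This is a minimal illustration with made-up values (`values` and `relu_values` are hypothetical names), using the same `np.maximum` call as the training loop.

```python
import numpy as np

values = np.array([-2.0, -0.5, 0.0, 1.5])

# Element-wise maximum with 0: negative entries become 0, others are unchanged.
relu_values = np.maximum(values, 0)
print(relu_values)
```

This is why the first column of h_relu above is all zeros: the corresponding pre-activation values were not positive.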