Skip to content

Commit 54ce4e3

Browse files
committed
Add unit tests to check that the neural net can approximate x=y and y=sin(x)
Also rename the unit tests file since all these tests actually belong to the neural net, not the tic tac toe.
1 parent 10de688 commit 54ce4e3

2 files changed

Lines changed: 87 additions & 24 deletions

File tree

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import org.scalatest._
2+
import collection.mutable.Stack
3+
import neuralNet._
4+
import neuralNet.NeuralNetUtilities._
5+
6+
abstract class UnitSpec extends FlatSpec with Matchers with
7+
OptionValues with Inside with Inspectors
8+
9+
class ExampleSpec extends FlatSpec with Matchers {
10+
11+
"A NeuralNet" should "correctly convert a state and action into a featureVector" in {
12+
var featureVetor = neuralNetFeatureVectorForStateAction(List("X", "", "", "", "", "", "" , "", ""), 2)
13+
featureVetor should equal (Array(1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0))
14+
featureVetor = neuralNetFeatureVectorForStateAction(List("X", "", "", "O", "", "O", "" , "", ""), 9)
15+
featureVetor should equal (Array(1.0, 0.0, 0.0, -1.0, 0.0, -1.0, 0.0, 0.0, 0.0, 9.0))
16+
}
17+
18+
it should "be able to learn sin(x)" in {
19+
val neuralNet = new NeuralNet(1, 20, 0.05, 1.0)
20+
var i = 0
21+
while (i < 100000) { // Train
22+
val x = scala.util.Random.nextDouble()
23+
val y = scala.math.sin(x)
24+
neuralNet.train(Array(x), y)
25+
i += 1
26+
}
27+
i = 0
28+
while (i < 1000) { // Test it works to a degree
29+
val x = scala.util.Random.nextDouble()
30+
val y = scala.math.sin(x)
31+
val result = neuralNet.feedForward(Array(x))
32+
var withinRange = false
33+
if (result < y + 0.1 && result > y - 0.1) {
34+
withinRange = true
35+
}
36+
else {
37+
println(s"x = ${x}")
38+
println(s"result = ${result}")
39+
}
40+
withinRange should equal (true)
41+
i += 1
42+
}
43+
while (i < 1000) { // Negative test to check that the test itself isn't broken
44+
val x = scala.util.Random.nextDouble()
45+
val y = scala.math.sin(x)
46+
val result = neuralNet.feedForward(Array(x))
47+
var withinRange = true
48+
if (result < y + 0.01 && result > y - 0.01) {
49+
withinRange = false
50+
}
51+
else {
52+
println(s"x = ${x}")
53+
println(s"result = ${result}")
54+
}
55+
withinRange should equal (false)
56+
i += 1
57+
}
58+
}
59+
60+
it should "be able to learn x=y" in {
61+
val neuralNet = new NeuralNet(1, 10, 0.1, 1.0)
62+
var i = 0
63+
while (i < 100000) { // Train
64+
val x = scala.util.Random.nextDouble()
65+
//println(s"input = ${x}")
66+
val result = neuralNet.feedForward(Array(x))
67+
//println(s"result = ${result}")
68+
neuralNet.train(Array(x), x)
69+
i += 1
70+
}
71+
i = 0
72+
while (i < 1000) {
73+
val x = scala.util.Random.nextDouble()
74+
val result = neuralNet.feedForward(Array(x))
75+
var withinRange = false
76+
if (result < x + 0.1 && result > x - 0.1) {
77+
withinRange = true
78+
}
79+
else {
80+
println(s"x = ${x}")
81+
println(s"result = ${result}")
82+
}
83+
withinRange should equal (true)
84+
i += 1
85+
}
86+
}
87+
}

src/test/scala/ticTacToeTests.scala

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)