Commit 814564a

Extend the plotting to handle stalemates and report wins for multiple agents

1 parent: c25a85d
1 file changed: 52 additions & 11 deletions

ticTacToe.scala
@@ -6,6 +6,7 @@
 // TODO: Implement SARSA(lambda)
 // TODO: Decrease epsilon over time. In the neural network case, potentially increase it in hopes of jumping out of local optima.
 // TODO: Improve the neural network's ability to approximate the value function
+// TODO: Implement a switch to turn off learning altogether, not merely exploration
 
 // Standard Library
 import java.awt.Graphics
@@ -106,9 +107,10 @@ object TicTacToeLearning {
 
 object PlotGenerator {
   def generateLearningCurves() {
-    val settings = List(/*(25000, 300, true, false, true, s"Tabular Learner vs. Random Agent, epsilon=${Parameters.epsilon} alpha=${Parameters.tabularAlpha}", "tabular_randomStart.pdf"),*/
-      /*(100000, 200, false, false, true, s"Neural Net vs. Random Agent, epsilon=${Parameters.epsilon} alpha=${Parameters.neuralAlpha} gamma=0.2", "neural_randomStart.pdf"),*/
-      (40000, 100, false, false, true, s"Neural Net vs. Random Agent, epsilon=${Parameters.epsilon} learningAlpha=${Parameters.neuralValueLearningAlpha} netAlpha=${Parameters.neuralNetAlpha} gamma=${Parameters.gamma} ${Parameters.neuralNumberHiddenNeurons} hidden neurons ${Parameters.neuralInitialBias} initial bias", "neural_vs_neural.pdf"))
+    val settings = List((25000, 200, true, false, true, s"Tabular vs. Random Agent, epsilon=${Parameters.epsilon} alpha=${Parameters.tabularAlpha}", "tabularVrandom.pdf", 1),
+      /*(50000, 100, false, false, true, s"Neural vs. Random Agent, epsilon=${Parameters.epsilon} learningAlpha=${Parameters.neuralValueLearningAlpha} netAlpha=${Parameters.neuralNetAlpha} gamma=${Parameters.gamma} ${Parameters.neuralNumberHiddenNeurons} hidden neurons ${Parameters.neuralInitialBias} initialBias", "neuralVrandom.pdf", 1),*/
+      (4000, 150, true, false, false, s"Tabular vs. Tabular, epsilon=${Parameters.epsilon} alpha=${Parameters.tabularAlpha}", "tabularVtabular.pdf", 2),
+      (40000, 100, false, false, false, s"Neural vs. Neural, epsilon=${Parameters.epsilon} learningAlpha=${Parameters.neuralValueLearningAlpha} netAlpha=${Parameters.neuralNetAlpha} gamma=${Parameters.gamma} ${Parameters.neuralNumberHiddenNeurons} hidden neurons ${Parameters.neuralInitialBias} initial bias", "neuralVneural.pdf", 3))
 
     for (setting <- settings) {
       val numberEpisodes = setting._1
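
Each tuple above is consumed positionally in the loop that follows. As a side note, the same layout could be named with a case class; a minimal sketch, with field names inferred from how the tuple slots are read in the surrounding code (the type itself is hypothetical, not in the repository):

    // Hypothetical sketch naming the positional fields of each settings tuple.
    case class PlotSetting(
      numberEpisodes: Int,    // setting._1
      numberIterations: Int,  // setting._2
      tabular: Boolean,       // setting._3
      playerXRandom: Boolean, // setting._4
      playerORandom: Boolean, // setting._5
      title: String,          // setting._6
      filename: String,       // setting._7
      plotting: Int           // setting._8: 1 = X wins, 2 = stalemates, 3 = both players' wins
    )
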
@@ -118,6 +120,7 @@ object TicTacToeLearning {
       val playerORandom = setting._5
       val title = setting._6
       val filename = setting._7
+      val plotting = setting._8 // 1: plot player X wins; 2: plot stalemates; 3: plot both player X and player O wins
 
       var i = 0
       val episodeNumbers : Seq[Double] = Seq.fill(numberEpisodes){0.0}
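
The Int flag keeps the diff small; an alternative sketch using a sealed trait would let the compiler check that every mode is handled (these types are hypothetical, not part of the commit):

    // Hypothetical alternative to the Int flag: one case object per plot mode.
    sealed trait PlotMode
    case object PlayerXWins extends PlotMode     // plotting == 1
    case object Stalemates extends PlotMode      // plotting == 2
    case object BothPlayersWins extends PlotMode // plotting == 3
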
@@ -126,29 +129,63 @@ object TicTacToeLearning {
         i += 1
       }
       var iteration = 0
-      val finalResults : Seq[Double] = Seq.fill(numberEpisodes){0.0}
+      val finalResults1 : Seq[Double] = Seq.fill(numberEpisodes){0.0}
+      val finalResults2 : Seq[Double] = Seq.fill(numberEpisodes){0.0}
       while (iteration < numberIterations) {
         println(s"Iteration ${iteration}/${numberIterations}")
         val results = playTrainingSession(numberEpisodes, tabular, playerXRandom, playerORandom, Parameters.epsilon)
         var i = 0
         for (result <- results) {
-          finalResults(i) = finalResults(i) + result
+          if (plotting == 1) {
+            if (result == 1) {
+              finalResults1(i) = finalResults1(i) + result
+            }
+          }
+          else if (plotting == 2) {
+            if (result == 0) {
+              finalResults1(i) = finalResults1(i) + 1
+            }
+          }
+          else if (plotting == 3) {
+            if (result == 1) {
+              finalResults1(i) = finalResults1(i) + 1
+            }
+            else if (result == -1) {
+              finalResults2(i) = finalResults2(i) + 1
+            }
+          }
           i += 1
         }
         iteration += 1
       }
 
       i = 0
-      for (result <- finalResults) {
-        finalResults(i) = finalResults(i) / numberIterations * 100.0
+      for (result <- finalResults1) {
+        finalResults1(i) = finalResults1(i).toDouble / numberIterations.toDouble * 100.0
+        i += 1
+      }
+      i = 0
+      for (result <- finalResults2) {
+        finalResults2(i) = finalResults2(i).toDouble / numberIterations.toDouble * 100.0
         i += 1
       }
 
       val f = Figure()
       val p = f.subplot(0)
-      p += plot(episodeNumbers, finalResults, '.')
+      if (plotting == 1 || plotting == 2) {
+        p += plot(episodeNumbers, finalResults1, '.')
+      }
+      else if (plotting == 3) {
+        p += plot(episodeNumbers, finalResults1, '.')
+        p += plot(episodeNumbers, finalResults2, '.')
+      }
       p.xlabel = "Episodes"
-      p.ylabel = s"% wins out of ${numberIterations.toInt} iterations"
+      if (plotting == 1 || plotting == 3) {
+        p.ylabel = s"% wins out of ${numberIterations.toInt} iterations"
+      }
+      else {
+        p.ylabel = s"% stalemates out of ${numberIterations.toInt} iterations"
+      }
       p.title = title
       f.saveas(filename)
     }
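
For reference, the branching above implements a single mapping from an episode's outcome (1.0 if player X won, -1.0 if player O won, 0.0 for a stalemate) to increments of the two counters. A minimal standalone restatement (function and parameter names are illustrative, not from the repository):

    // Returns (increment for finalResults1, increment for finalResults2).
    def increments(plotting: Int, result: Double): (Double, Double) = plotting match {
      case 1 => (if (result == 1.0) 1.0 else 0.0, 0.0)   // count X wins
      case 2 => (if (result == 0.0) 1.0 else 0.0, 0.0)   // count stalemates
      case 3 => (if (result == 1.0) 1.0 else 0.0,        // count X wins...
                 if (result == -1.0) 1.0 else 0.0)       // ...and O wins separately
      case _ => (0.0, 0.0)
    }
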
@@ -170,9 +207,13 @@ object TicTacToeLearning {
 
   def playEpisode(ticTacToeWorld : TicTacToeWorld, epsilon : Double, collectingDataFor : String) : Double = {
     var episodeOutcome = -2.0
-    while (episodeOutcome == -2.0) {
+    while (episodeOutcome == -2.0) { // Train with epsilon
      episodeOutcome = iterateGameStep(ticTacToeWorld, epsilon, None, collectingDataFor)
     }
+    episodeOutcome = -2.0
+    while (episodeOutcome == -2.0) { // Test run with epsilon = 0
+      episodeOutcome = iterateGameStep(ticTacToeWorld, 0.0, None, collectingDataFor)
+    }
     return episodeOutcome
   }
 
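The effect of this hunk: every episode now plays two games, an epsilon-greedy one that lets the learner explore and update, then a greedy (epsilon = 0) one whose outcome is returned, so the curves above measure the greedy policy rather than the exploring one. The pattern in isolation (a hypothetical helper; playGame stands in for the loop around iterateGameStep):

    // Sketch of the train-then-evaluate pattern; returns the greedy game's outcome.
    def trainThenEvaluate(epsilon: Double, playGame: Double => Double): Double = {
      playGame(epsilon) // one epsilon-greedy game: explores and updates values
      playGame(0.0)     // one greedy game: its outcome is what gets plotted
    }
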
@@ -187,7 +228,7 @@ object TicTacToeLearning {
       returnValue = 1.0
     }
     else if (environment.playerWon(environment.getOtherAgent(ticTacToeWorld.agent1))) {
-      returnValue = 0.0
+      returnValue = -1.0
     }
     else {
       returnValue = 0.0
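
This is the change that makes the new plotting modes possible: 0.0 previously meant both an O win and a stalemate, and now it means only a stalemate. The resulting three-valued encoding, written out as illustrative constants (not names from the repository):

    // Outcome encoding of playEpisode after this commit (illustrative names only):
    val XWins     =  1.0  // agent1 (player X) won
    val OWins     = -1.0  // the other agent (player O) won; previously reported as 0.0
    val Stalemate =  0.0  // no longer ambiguous with an O win
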
