Skip to content

Commit 019aebf

Browse files
committed
Added the SGD example: (mimicker)
1 parent 8b5a633 commit 019aebf

6 files changed

Lines changed: 174 additions & 2 deletions

File tree

Examples/FlappyTest/Game.gd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ func _ready() -> void:
2929

3030
$Players.add_child(player)
3131

32+
GeneticEvolution.show_popup()
3233
spawn_obstacle()
3334

3435

Examples/SGD/mimicker.gd

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
extends Control
2+
3+
## In this example, The AI is trained to mimic whatever input the user is giving.
4+
5+
## The AI is trained using using mini-batch stochastic gradient descent.
6+
7+
var net :Network
8+
9+
func _ready() -> void:
10+
net = Network.new([5, 20, 5])
11+
# NOTE: it took roughly 12 seconds for my computer to train AI under the current settings.
12+
net.SGD(generate_training_data(), 1, 5, 1, generate_training_data().slice(0, 100))
13+
net.add_visualizer($HBoxContainer/visual)
14+
15+
16+
func generate_training_data() -> Array[Array]:
17+
# Yes, I know there are many duplicates in the data array.
18+
var data: Array[Array] = []
19+
for a in 5:
20+
for b in 5:
21+
for c in 5:
22+
for d in 5:
23+
for e in 5:
24+
var sample: Array[Matrix] = [Matrix.new(5, 1)]
25+
sample[0].set_index(a, 0, 1)
26+
sample[0].set_index(b, 0, 1)
27+
sample[0].set_index(c, 0, 1)
28+
sample[0].set_index(d, 0, 1)
29+
sample[0].set_index(e, 0, 1)
30+
# generate corresponding output
31+
sample.append(sample[0].clone())
32+
data.append(sample)
33+
## There are very less entries of samples where only one input is switched on, let's add more
34+
## of them to the mis
35+
for a in 5:
36+
var sample: Array[Matrix] = [Matrix.new(5, 1)]
37+
sample[0].set_index(a, 0, 1)
38+
# generate corresponding output
39+
sample.append(sample[0].clone())
40+
for b in 30:
41+
data.append(sample)
42+
## Now stir the mix, and feed it to the AI
43+
data.shuffle()
44+
return data
45+
46+
47+
func update(button_pressed: bool) -> void:
48+
var inputs: PackedFloat32Array = []
49+
for input in %Inputs.get_children():
50+
inputs.append(input.button_pressed == true)
51+
var out = net.feedforward(inputs).to_array()
52+
for i in %Outputs.get_child_count():
53+
var output: CheckButton = %Outputs.get_child(i)
54+
output.button_pressed = out[i] > 0.7

Examples/SGD/mimicker.tscn

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
[gd_scene load_steps=2 format=3 uid="uid://cmg6es5cad8u2"]
2+
3+
[ext_resource type="Script" path="res://Examples/SGD/mimicker.gd" id="1_c78eo"]
4+
5+
[node name="Mimicker" type="Control"]
6+
layout_mode = 3
7+
anchors_preset = 15
8+
anchor_right = 1.0
9+
anchor_bottom = 1.0
10+
grow_horizontal = 2
11+
grow_vertical = 2
12+
script = ExtResource("1_c78eo")
13+
14+
[node name="HBoxContainer" type="HBoxContainer" parent="."]
15+
layout_mode = 2
16+
offset_right = 1152.0
17+
offset_bottom = 648.0
18+
size_flags_vertical = 3
19+
20+
[node name="Inputs" type="VBoxContainer" parent="HBoxContainer"]
21+
unique_name_in_owner = true
22+
layout_mode = 2
23+
size_flags_horizontal = 3
24+
alignment = 1
25+
26+
[node name="Input1" type="CheckBox" parent="HBoxContainer/Inputs"]
27+
unique_name_in_owner = true
28+
layout_mode = 2
29+
size_flags_horizontal = 4
30+
text = "Human Input 1"
31+
32+
[node name="Input2" type="CheckBox" parent="HBoxContainer/Inputs"]
33+
unique_name_in_owner = true
34+
layout_mode = 2
35+
size_flags_horizontal = 4
36+
text = "Human Input 2"
37+
38+
[node name="Input3" type="CheckBox" parent="HBoxContainer/Inputs"]
39+
unique_name_in_owner = true
40+
layout_mode = 2
41+
size_flags_horizontal = 4
42+
text = "Human Input 3"
43+
44+
[node name="Input4" type="CheckBox" parent="HBoxContainer/Inputs"]
45+
unique_name_in_owner = true
46+
layout_mode = 2
47+
size_flags_horizontal = 4
48+
text = "Human Input 4"
49+
50+
[node name="Input5" type="CheckBox" parent="HBoxContainer/Inputs"]
51+
unique_name_in_owner = true
52+
layout_mode = 2
53+
size_flags_horizontal = 4
54+
text = "Human Input 5"
55+
56+
[node name="visual" type="AspectRatioContainer" parent="HBoxContainer"]
57+
layout_mode = 2
58+
size_flags_horizontal = 3
59+
60+
[node name="Outputs" type="VBoxContainer" parent="HBoxContainer"]
61+
unique_name_in_owner = true
62+
layout_mode = 2
63+
size_flags_horizontal = 3
64+
alignment = 1
65+
66+
[node name="CheckButton" type="CheckButton" parent="HBoxContainer/Outputs"]
67+
layout_mode = 2
68+
size_flags_horizontal = 4
69+
disabled = true
70+
text = "AI Output 1"
71+
72+
[node name="CheckButton2" type="CheckButton" parent="HBoxContainer/Outputs"]
73+
layout_mode = 2
74+
size_flags_horizontal = 4
75+
disabled = true
76+
text = "AI Output 2"
77+
78+
[node name="CheckButton3" type="CheckButton" parent="HBoxContainer/Outputs"]
79+
layout_mode = 2
80+
size_flags_horizontal = 4
81+
disabled = true
82+
text = "AI Output 3"
83+
84+
[node name="CheckButton4" type="CheckButton" parent="HBoxContainer/Outputs"]
85+
layout_mode = 2
86+
size_flags_horizontal = 4
87+
disabled = true
88+
text = "AI Output 4"
89+
90+
[node name="CheckButton5" type="CheckButton" parent="HBoxContainer/Outputs"]
91+
layout_mode = 2
92+
size_flags_horizontal = 4
93+
disabled = true
94+
text = "AI Output 5"
95+
96+
[connection signal="toggled" from="HBoxContainer/Inputs/Input1" to="." method="update"]
97+
[connection signal="toggled" from="HBoxContainer/Inputs/Input2" to="." method="update"]
98+
[connection signal="toggled" from="HBoxContainer/Inputs/Input3" to="." method="update"]
99+
[connection signal="toggled" from="HBoxContainer/Inputs/Input4" to="." method="update"]
100+
[connection signal="toggled" from="HBoxContainer/Inputs/Input5" to="." method="update"]

addons/NeuralNetwork/AutoloadAlgorithms/GeneticEvolution.gd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ func _ready() -> void:
2828
preload("res://addons/NeuralNetwork/NetworkVisualizer/VisualizerPopup.tscn").instantiate()
2929
)
3030
add_child(visualizer_popup)
31+
visualizer_popup.hide()
32+
33+
34+
func show_popup() -> void:
3135
visualizer_popup.popup_centered()
3236

3337

addons/NeuralNetwork/Classes/Network.gd

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ func SGD(
8686
var mini_batch: Array[Array] = training_data.slice(k, k + mini_batch_size)
8787
update_mini_batch(mini_batch, eta)
8888
if test_data:
89-
print("Epoch %s: %s / %s" % [str(j), str(evaluate(test_data)), str(test_data.size())])
89+
print("Epoch %s: %s" % [str(j), str(evaluate_cost(test_data))])
90+
#print("Epoch %s: %s / %s" % [str(j), str(evaluate(test_data)), str(test_data.size())])
9091
else:
9192
print("Epoch %s complete" % str(j))
9293

@@ -172,6 +173,18 @@ func evaluate(test_data: Array[Array]) -> int:
172173
return sum
173174

174175

176+
## A more general wey of visualizing performance of the network. The lower the cost, the better the
177+
## Performance.
178+
func evaluate_cost(test_data: Array[Array]) -> float:
179+
var test_results: Array[Array] = []
180+
var cost_sum: float = 0
181+
for sample: Array[Matrix] in test_data:
182+
var sum_i = feedforward(sample[0].to_array()).subtract_from(sample[1])
183+
var sum_i_square := Array(sum_i.multiply_corresponding(sum_i).to_array())
184+
cost_sum += sum_i_square.reduce(func(accum, number): return accum + number)
185+
return cost_sum / test_data.size()
186+
187+
175188
## Return the vector of partial derivatives
176189
## (partial C_x / partial a) for the output activations.
177190
func cost_derivative(output_activations: Matrix, y: Matrix) -> Matrix:

project.godot

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ config_version=5
1111
[application]
1212

1313
config/name="NeuralNetwork"
14-
run/main_scene="res://Examples/FlappyTest/Game.tscn"
14+
run/main_scene="res://Examples/SGD/mimicker.tscn"
1515
config/features=PackedStringArray("4.2")
1616

1717
[autoload]

0 commit comments

Comments
 (0)