Skip to content

Commit 0adb703

Browse files
author
agoroot
committed
added test
1 parent 6ff10d2 commit 0adb703

1 file changed

Lines changed: 328 additions & 0 deletions

File tree

tests/test_fun_dynamics.py

Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
"""
2+
Test the basic/key dynamics of an NGC system constructed with Node(s) and Cable(s)
3+
with synthetic data as a test case.
4+
-------------------------------------------------------------------------------
5+
This code runs the "identity test" for NGC projection and dynamics:
6+
+ create a 3-layer model, w/ tied forward/backward weights and forward weights
7+
initialized to a diagonal, that performs x |-> y w/ relu activation and
8+
identity output
9+
+ set x = 1 and clamp it to input of projection graph, should return 1
10+
+ clamp x to ngc graph to input and output/targets states, when run for a
11+
number of settling steps, should return 1 and remain in equilibrium where
12+
all internal error nodes are exactly 0.
13+
Repeat the above test for a 3-layer model initialized exactly the same except
14+
with untied forward/backward weights (backward weights initialized to diagonal)
15+
16+
For both model types above, check (after simulating the settling process) that
17+
the synaptic updates for every single weight and bias are exactly matrices of zeros.
18+
-------------------------------------------------------------------------------
19+
20+
The following models are tested:
21+
a 3-layer x->y NGC with tied forward/error weights
22+
a 3-layer x->y NGC with untied forward/error weights
23+
24+
This (non-exhaustive) test script checks for qualitative irregularities in each
25+
model's behavior and, indirectly, the functioning of ngc-learn's nodes and cables.
26+
Please read the documentations in /docs/ for an overview and description/use of
27+
Nodes and Cables as well as for practical use-cases through the
28+
demonstrations (under the /examples directory).
29+
30+
As this is strictly a qualitative set of tests, it is up to the developer / user
31+
to examine for specific irregularities.
32+
"""
33+
34+
import os
35+
import sys, getopt, optparse
36+
import pickle
37+
sys.path.insert(0, '../')
38+
import tensorflow as tf
39+
import numpy as np
40+
import time
41+
42+
# import general simulation utilities
43+
from ngclearn.utils.config import Config
44+
import ngclearn.utils.transform_utils as transform
45+
import ngclearn.utils.metric_utils as metric
46+
import ngclearn.utils.io_utils as io_tools
47+
from ngclearn.utils.data_utils import DataLoader
48+
49+
from ngclearn.engine.nodes.snode import SNode
50+
from ngclearn.engine.nodes.enode import ENode
51+
from ngclearn.engine.ngc_graph import NGCGraph
52+
53+
from ngclearn.engine.nodes.fnode import FNode
54+
from ngclearn.engine.proj_graph import ProjectionGraph
55+
56+
seed = 69
57+
tf.random.set_seed(seed=seed)
58+
np.random.seed(seed)
59+
60+
# set up parameters of tests
61+
62+
x = tf.ones([1,10])
63+
x_dim = x.shape[1]
64+
z3_dim = x_dim
65+
z2_dim = z3_dim
66+
z1_dim = z2_dim
67+
z0_dim = z1_dim
68+
69+
seed = 69
70+
beta = 1 # must fix step to 1.0 for this test
71+
integrate_cfg = {"integrate_type" : "euler", "use_dfx" : True}
72+
73+
print("#######################################################################")
74+
print(" > Testing a proxy NGC graph w/ tied weights")
75+
# set up system nodes
76+
z3 = SNode(name="z3", dim=z3_dim, beta=beta, leak=0, act_fx="identity",
77+
integrate_kernel=integrate_cfg)
78+
mu2 = SNode(name="mu2", dim=z2_dim, act_fx="identity", zeta=0.0)
79+
e2 = ENode(name="e2", dim=z2_dim)
80+
z2 = SNode(name="z2", dim=z2_dim, beta=beta, leak=0, act_fx="relu",
81+
integrate_kernel=integrate_cfg)
82+
mu1 = SNode(name="mu1", dim=z1_dim, act_fx="identity", zeta=0.0)
83+
e1 = ENode(name="e1", dim=z1_dim)
84+
z1 = SNode(name="z1", dim=z1_dim, beta=beta, leak=0, act_fx="relu",
85+
integrate_kernel=integrate_cfg)
86+
mu0 = SNode(name="mu0", dim=z0_dim, act_fx="identity", zeta=0.0)
87+
e0 = ENode(name="e0", dim=z0_dim)
88+
z0 = SNode(name="z0", dim=z0_dim, beta=beta, leak=0.0,
89+
integrate_kernel=integrate_cfg)
90+
91+
# create cable wiring scheme relating nodes to one another
92+
dcable_cfg = {"type": "dense", "has_bias": True,
93+
"init" : ("diagonal",1), "seed" : seed} #classic_glorot
94+
pos_scable_cfg = {"type": "simple", "coeff": 1.0}
95+
neg_scable_cfg = {"type": "simple", "coeff": -1.0}
96+
97+
z3_mu2 = z3.wire_to(mu2, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
98+
mu2.wire_to(e2, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
99+
z2.wire_to(e2, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
100+
e2.wire_to(z3, src_var="phi(z)", dest_var="dz_bu", mirror_path_kernel=(z3_mu2,"symm_tied"))
101+
e2.wire_to(z2, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)
102+
103+
z2_mu1 = z2.wire_to(mu1, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
104+
mu1.wire_to(e1, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
105+
z1.wire_to(e1, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
106+
e1.wire_to(z2, src_var="phi(z)", dest_var="dz_bu", mirror_path_kernel=(z2_mu1,"symm_tied"))
107+
e1.wire_to(z1, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)
108+
109+
z1_mu0 = z1.wire_to(mu0, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
110+
mu0.wire_to(e0, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
111+
z0.wire_to(e0, src_var="phi(z)", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
112+
e0.wire_to(z1, src_var="phi(z)", dest_var="dz_bu", mirror_path_kernel=(z1_mu0,"symm_tied"))
113+
e0.wire_to(z0, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)
114+
115+
# set up update rules and make relevant edges aware of these
116+
z3_mu2.set_update_rule(preact=(z3,"phi(z)"), postact=(e2,"phi(z)"))
117+
z2_mu1.set_update_rule(preact=(z2,"phi(z)"), postact=(e1,"phi(z)"))
118+
z1_mu0.set_update_rule(preact=(z1,"phi(z)"), postact=(e0,"phi(z)"))
119+
120+
# Set up graph - execution cycle/order
121+
ngc_model = NGCGraph(K=10, name="gncn_t1_ffm")
122+
ngc_model.proj_update_mag = -1.0
123+
ngc_model.proj_weight_mag = -1.0
124+
ngc_model.set_cycle(nodes=[z3,z2,z1,z0])
125+
ngc_model.set_cycle(nodes=[mu2,mu1,mu0])
126+
ngc_model.set_cycle(nodes=[e2,e1,e0])
127+
ngc_model.apply_constraints()
128+
129+
# build this NGC model's sampling graph
130+
z3_dim = ngc_model.getNode("z3").dim
131+
z2_dim = ngc_model.getNode("z2").dim
132+
z1_dim = ngc_model.getNode("z1").dim
133+
z0_dim = ngc_model.getNode("z0").dim
134+
# Set up complementary sampling graph to use in conjunction w/ NGC-graph
135+
s3 = FNode(name="s3", dim=z3_dim, act_fx="identity")
136+
s2 = FNode(name="s2", dim=z2_dim, act_fx="relu")
137+
s1 = FNode(name="s1", dim=z1_dim, act_fx="relu")
138+
s0 = FNode(name="s0", dim=z0_dim, act_fx="identity")
139+
s3_s2 = s3.wire_to(s2, src_var="phi(z)", dest_var="dz", point_to_path=z3_mu2)
140+
s2_s1 = s2.wire_to(s1, src_var="phi(z)", dest_var="dz", point_to_path=z2_mu1)
141+
s1_s0 = s1.wire_to(s0, src_var="phi(z)", dest_var="dz", point_to_path=z1_mu0)
142+
sampler = ProjectionGraph()
143+
sampler.set_cycle(nodes=[s3,s2,s1,s0])
144+
145+
# test projection graph
146+
print("----------------")
147+
print(" > Checking ancestral projection graph:")
148+
readouts = sampler.project(
149+
clamped_vars=[("s3","z",x)],
150+
readout_vars=[("s0","phi(z)")]
151+
)
152+
x_sample = readouts[0][2]
153+
154+
print(" => Test for: x = 1 = s2.z = s2.phi(z)")
155+
s3_z = sampler.extract("s3","z")
156+
s3_phi = sampler.extract("s3","phi(z)")
157+
np.testing.assert_array_equal(x.numpy(), s3_z.numpy())
158+
np.testing.assert_array_equal(x.numpy(), s3_phi.numpy())
159+
print(" PASS!")
160+
161+
print(" => Test for: x = 1 = s1.z = s1.phi(z)")
162+
s1_z = sampler.extract("s1","z")
163+
s1_phi = sampler.extract("s1","phi(z)")
164+
np.testing.assert_array_equal(x.numpy(), s1_z.numpy())
165+
np.testing.assert_array_equal(x.numpy(), s1_phi.numpy())
166+
print(" PASS!")
167+
168+
print(" => Test for: x = 1 = s0.z = s0.phi(z)")
169+
s0_z = sampler.extract("s0","z")
170+
s0_phi = sampler.extract("s0","phi(z)")
171+
np.testing.assert_array_equal(x.numpy(), s0_z.numpy())
172+
np.testing.assert_array_equal(x.numpy(), s0_phi.numpy())
173+
print(" PASS!")
174+
175+
print(" => Test for: x = x_sample")
176+
print("Expected: ",x.numpy())
177+
print(" Output: ",x_sample.numpy())
178+
np.testing.assert_array_equal(x.numpy(), x_sample.numpy())
179+
print(" PASS!")
180+
181+
182+
# test NGC simulation object
183+
print("----------------")
184+
print(" > Checking NGC simulation object:")
185+
readouts = ngc_model.settle(
186+
clamped_vars=[("z3","z",x),("z0","z",x)],
187+
readout_vars=[("mu0","phi(z)"),("mu1","phi(z)"),("mu2","phi(z)")]
188+
)
189+
x_hat = readouts[0][2]
190+
191+
print(" => Test for: 0 = e2.z = e2.phi(z)")
192+
target_value = np.zeros([1,z2_dim])
193+
e2_z = ngc_model.extract("e2","z")
194+
e2_phi = ngc_model.extract("e2","phi(z)")
195+
np.testing.assert_array_equal(target_value, e2_z.numpy())
196+
np.testing.assert_array_equal(target_value, e2_phi.numpy())
197+
print(" PASS!")
198+
199+
print(" => Test for: 0 = e1.z = e1.phi(z)")
200+
target_value = np.zeros([1,z1_dim])
201+
e1_z = ngc_model.extract("e1","z")
202+
e1_phi = ngc_model.extract("e1","phi(z)")
203+
np.testing.assert_array_equal(target_value, e1_z.numpy())
204+
np.testing.assert_array_equal(target_value, e1_phi.numpy())
205+
print(" PASS!")
206+
207+
print(" => Test for: 0 = e0.z = e0.phi(z)")
208+
target_value = np.zeros([1,z0_dim])
209+
e0_z = ngc_model.extract("e0","z")
210+
e0_phi = ngc_model.extract("e0","phi(z)")
211+
np.testing.assert_array_equal(target_value, e0_z.numpy())
212+
np.testing.assert_array_equal(target_value, e0_phi.numpy())
213+
print(" PASS!")
214+
215+
print(" => Test for: x = x_hat")
216+
print("Expected: ",x.numpy())
217+
print(" Output: ",x_hat.numpy())
218+
np.testing.assert_array_equal(x.numpy(), x_hat.numpy())
219+
print(" PASS!")
220+
221+
print(" => Test for update calculation: all dx should be = 0")
222+
delta = ngc_model.calc_updates()
223+
for i in range(len(delta)):
224+
target_dx = ngc_model.theta[i] * 0
225+
dx = delta[i]
226+
np.testing.assert_array_equal(target_dx.numpy(), dx.numpy())
227+
print(" PASS! (for all {} dx calculations)".format(len(delta)))
228+
229+
print("#######################################################################")
230+
231+
print("#######################################################################")
232+
print(" > Testing a proxy NGC graph w/ untied weights")
233+
# set up system nodes
234+
z3 = SNode(name="z3", dim=z3_dim, beta=beta, leak=0, act_fx="identity",
235+
integrate_kernel=integrate_cfg)
236+
mu2 = SNode(name="mu2", dim=z2_dim, act_fx="identity", zeta=0.0)
237+
e2 = ENode(name="e2", dim=z2_dim)
238+
z2 = SNode(name="z2", dim=z2_dim, beta=beta, leak=0, act_fx="relu",
239+
integrate_kernel=integrate_cfg)
240+
mu1 = SNode(name="mu1", dim=z1_dim, act_fx="identity", zeta=0.0)
241+
e1 = ENode(name="e1", dim=z1_dim)
242+
z1 = SNode(name="z1", dim=z1_dim, beta=beta, leak=0, act_fx="relu",
243+
integrate_kernel=integrate_cfg)
244+
mu0 = SNode(name="mu0", dim=z0_dim, act_fx="identity", zeta=0.0)
245+
e0 = ENode(name="e0", dim=z0_dim)
246+
z0 = SNode(name="z0", dim=z0_dim, beta=beta, leak=0.0,
247+
integrate_kernel=integrate_cfg)
248+
249+
# create cable wiring scheme relating nodes to one another
250+
z3_mu2 = z3.wire_to(mu2, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
251+
mu2.wire_to(e2, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
252+
z2.wire_to(e2, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
253+
e2.wire_to(z3, src_var="phi(z)", dest_var="dz_bu", cable_kernel=dcable_cfg)
254+
e2.wire_to(z2, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)
255+
256+
z2_mu1 = z2.wire_to(mu1, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
257+
mu1.wire_to(e1, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
258+
z1.wire_to(e1, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
259+
e1.wire_to(z2, src_var="phi(z)", dest_var="dz_bu", cable_kernel=dcable_cfg)
260+
e1.wire_to(z1, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)
261+
262+
z1_mu0 = z1.wire_to(mu0, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg)
263+
mu0.wire_to(e0, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg)
264+
z0.wire_to(e0, src_var="phi(z)", dest_var="pred_targ", cable_kernel=pos_scable_cfg)
265+
e0.wire_to(z1, src_var="phi(z)", dest_var="dz_bu", cable_kernel=dcable_cfg)
266+
e0.wire_to(z0, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg)
267+
268+
# set up update rules and make relevant edges aware of these
269+
z3_mu2.set_update_rule(preact=(z3,"phi(z)"), postact=(e2,"phi(z)"))
270+
z2_mu1.set_update_rule(preact=(z2,"phi(z)"), postact=(e1,"phi(z)"))
271+
z1_mu0.set_update_rule(preact=(z1,"phi(z)"), postact=(e0,"phi(z)"))
272+
273+
# Set up graph - execution cycle/order
274+
ngc_model = NGCGraph(K=10, name="gncn_t1_ffm")
275+
ngc_model.proj_update_mag = -1.0
276+
ngc_model.proj_weight_mag = -1.0
277+
ngc_model.set_cycle(nodes=[z3,z2,z1,z0])
278+
ngc_model.set_cycle(nodes=[mu2,mu1,mu0])
279+
ngc_model.set_cycle(nodes=[e2,e1,e0])
280+
ngc_model.apply_constraints()
281+
282+
print("----------------")
283+
print(" > Checking NGC simulation object:")
284+
readouts = ngc_model.settle(
285+
clamped_vars=[("z3","z",x),("z0","z",x)],
286+
readout_vars=[("mu0","phi(z)"),("mu1","phi(z)"),("mu2","phi(z)")]
287+
)
288+
x_hat = readouts[0][2]
289+
290+
print(" => Test for: 0 = e2.z = e2.phi(z)")
291+
target_value = np.zeros([1,z2_dim])
292+
e2_z = ngc_model.extract("e2","z")
293+
e2_phi = ngc_model.extract("e2","phi(z)")
294+
np.testing.assert_array_equal(target_value, e2_z.numpy())
295+
np.testing.assert_array_equal(target_value, e2_phi.numpy())
296+
print(" PASS!")
297+
298+
print(" => Test for: 0 = e1.z = e1.phi(z)")
299+
target_value = np.zeros([1,z1_dim])
300+
e1_z = ngc_model.extract("e1","z")
301+
e1_phi = ngc_model.extract("e1","phi(z)")
302+
np.testing.assert_array_equal(target_value, e1_z.numpy())
303+
np.testing.assert_array_equal(target_value, e1_phi.numpy())
304+
print(" PASS!")
305+
306+
print(" => Test for: 0 = e0.z = e0.phi(z)")
307+
target_value = np.zeros([1,z0_dim])
308+
e0_z = ngc_model.extract("e0","z")
309+
e0_phi = ngc_model.extract("e0","phi(z)")
310+
np.testing.assert_array_equal(target_value, e0_z.numpy())
311+
np.testing.assert_array_equal(target_value, e0_phi.numpy())
312+
print(" PASS!")
313+
314+
print(" => Test for: x = x_hat")
315+
print("Expected: ",x.numpy())
316+
print(" Output: ",x_hat.numpy())
317+
np.testing.assert_array_equal(x.numpy(), x_hat.numpy())
318+
print(" PASS!")
319+
320+
print(" => Test for update calculation: all dx should be = 0")
321+
delta = ngc_model.calc_updates()
322+
for i in range(len(delta)):
323+
target_dx = ngc_model.theta[i] * 0
324+
dx = delta[i]
325+
np.testing.assert_array_equal(target_dx.numpy(), dx.numpy())
326+
print(" PASS! (for all {} dx calculations)".format(len(delta)))
327+
328+
print("#######################################################################")

0 commit comments

Comments
 (0)