|
| 1 | +""" |
| 2 | +Test the basic/key dynamics of an NGC system constructed with Node(s) and Cable(s) |
| 3 | +with synthetic data as a test case. |
| 4 | +------------------------------------------------------------------------------- |
| 5 | +This code runs the "identity test" for NGC projection and dynamics: |
| 6 | ++ create a 3-layer model, w/ tied forward/backward weights and forward weights |
| 7 | + initialized to a diagonal, that performs x |-> y w/ relu activation and |
| 8 | + identity output |
| 9 | ++ set x = 1 and clamp it to input of projection graph, should return 1 |
| 10 | ++ clamp x to ngc graph to input and output/targets states, when run for a |
| 11 | + number of settling steps, should return 1 and remain in equilibrium where |
| 12 | + all internal error nodes are exactly 0. |
| 13 | +Repeat the above test for a 3-layer model initialized exactly the same except |
| 14 | +with untied forward/backward weights (backward weights initialized to diagonal) |
| 15 | +
|
| 16 | +For both model types above, check (after simulating the settling process) that |
| 17 | +the synaptic updates for every single weight and bias are exactly matrices of zeros. |
| 18 | +------------------------------------------------------------------------------- |
| 19 | +
|
| 20 | +The following models are tested: |
| 21 | +a 3-layer x->y NGC with tied forward/error weights |
| 22 | +a 3-layer x->y NGC with untied forward/error weights |
| 23 | +
|
| 24 | +This (non-exhaustive) test script checks for qualitative irregularities in each |
| 25 | +model's behavior and, indirectly, the functioning of ngc-learn's nodes and cables. |
| 26 | +Please read the documentations in /docs/ for an overview and description/use of |
| 27 | +Nodes and Cables as well as for practical use-cases through the |
| 28 | +demonstrations (under the /examples directory). |
| 29 | +
|
| 30 | +As this is strictly a qualitative set of tests, it is up to the developer / user |
| 31 | +to examine for specific irregularities. |
| 32 | +""" |
| 33 | + |
| 34 | +import os |
| 35 | +import sys, getopt, optparse |
| 36 | +import pickle |
| 37 | +sys.path.insert(0, '../') |
| 38 | +import tensorflow as tf |
| 39 | +import numpy as np |
| 40 | +import time |
| 41 | + |
| 42 | +# import general simulation utilities |
| 43 | +from ngclearn.utils.config import Config |
| 44 | +import ngclearn.utils.transform_utils as transform |
| 45 | +import ngclearn.utils.metric_utils as metric |
| 46 | +import ngclearn.utils.io_utils as io_tools |
| 47 | +from ngclearn.utils.data_utils import DataLoader |
| 48 | + |
| 49 | +from ngclearn.engine.nodes.snode import SNode |
| 50 | +from ngclearn.engine.nodes.enode import ENode |
| 51 | +from ngclearn.engine.ngc_graph import NGCGraph |
| 52 | + |
| 53 | +from ngclearn.engine.nodes.fnode import FNode |
| 54 | +from ngclearn.engine.proj_graph import ProjectionGraph |
| 55 | + |
| 56 | +seed = 69 |
| 57 | +tf.random.set_seed(seed=seed) |
| 58 | +np.random.seed(seed) |
| 59 | + |
| 60 | +# set up parameters of tests |
| 61 | + |
| 62 | +x = tf.ones([1,10]) |
| 63 | +x_dim = x.shape[1] |
| 64 | +z3_dim = x_dim |
| 65 | +z2_dim = z3_dim |
| 66 | +z1_dim = z2_dim |
| 67 | +z0_dim = z1_dim |
| 68 | + |
| 69 | +seed = 69 |
| 70 | +beta = 1 # must fix step to 1.0 for this test |
| 71 | +integrate_cfg = {"integrate_type" : "euler", "use_dfx" : True} |
| 72 | + |
| 73 | +print("#######################################################################") |
| 74 | +print(" > Testing a proxy NGC graph w/ tied weights") |
| 75 | +# set up system nodes |
| 76 | +z3 = SNode(name="z3", dim=z3_dim, beta=beta, leak=0, act_fx="identity", |
| 77 | + integrate_kernel=integrate_cfg) |
| 78 | +mu2 = SNode(name="mu2", dim=z2_dim, act_fx="identity", zeta=0.0) |
| 79 | +e2 = ENode(name="e2", dim=z2_dim) |
| 80 | +z2 = SNode(name="z2", dim=z2_dim, beta=beta, leak=0, act_fx="relu", |
| 81 | + integrate_kernel=integrate_cfg) |
| 82 | +mu1 = SNode(name="mu1", dim=z1_dim, act_fx="identity", zeta=0.0) |
| 83 | +e1 = ENode(name="e1", dim=z1_dim) |
| 84 | +z1 = SNode(name="z1", dim=z1_dim, beta=beta, leak=0, act_fx="relu", |
| 85 | + integrate_kernel=integrate_cfg) |
| 86 | +mu0 = SNode(name="mu0", dim=z0_dim, act_fx="identity", zeta=0.0) |
| 87 | +e0 = ENode(name="e0", dim=z0_dim) |
| 88 | +z0 = SNode(name="z0", dim=z0_dim, beta=beta, leak=0.0, |
| 89 | + integrate_kernel=integrate_cfg) |
| 90 | + |
| 91 | +# create cable wiring scheme relating nodes to one another |
| 92 | +dcable_cfg = {"type": "dense", "has_bias": True, |
| 93 | + "init" : ("diagonal",1), "seed" : seed} #classic_glorot |
| 94 | +pos_scable_cfg = {"type": "simple", "coeff": 1.0} |
| 95 | +neg_scable_cfg = {"type": "simple", "coeff": -1.0} |
| 96 | + |
| 97 | +z3_mu2 = z3.wire_to(mu2, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg) |
| 98 | +mu2.wire_to(e2, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg) |
| 99 | +z2.wire_to(e2, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg) |
| 100 | +e2.wire_to(z3, src_var="phi(z)", dest_var="dz_bu", mirror_path_kernel=(z3_mu2,"symm_tied")) |
| 101 | +e2.wire_to(z2, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg) |
| 102 | + |
| 103 | +z2_mu1 = z2.wire_to(mu1, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg) |
| 104 | +mu1.wire_to(e1, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg) |
| 105 | +z1.wire_to(e1, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg) |
| 106 | +e1.wire_to(z2, src_var="phi(z)", dest_var="dz_bu", mirror_path_kernel=(z2_mu1,"symm_tied")) |
| 107 | +e1.wire_to(z1, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg) |
| 108 | + |
| 109 | +z1_mu0 = z1.wire_to(mu0, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg) |
| 110 | +mu0.wire_to(e0, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg) |
| 111 | +z0.wire_to(e0, src_var="phi(z)", dest_var="pred_targ", cable_kernel=pos_scable_cfg) |
| 112 | +e0.wire_to(z1, src_var="phi(z)", dest_var="dz_bu", mirror_path_kernel=(z1_mu0,"symm_tied")) |
| 113 | +e0.wire_to(z0, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg) |
| 114 | + |
| 115 | +# set up update rules and make relevant edges aware of these |
| 116 | +z3_mu2.set_update_rule(preact=(z3,"phi(z)"), postact=(e2,"phi(z)")) |
| 117 | +z2_mu1.set_update_rule(preact=(z2,"phi(z)"), postact=(e1,"phi(z)")) |
| 118 | +z1_mu0.set_update_rule(preact=(z1,"phi(z)"), postact=(e0,"phi(z)")) |
| 119 | + |
| 120 | +# Set up graph - execution cycle/order |
| 121 | +ngc_model = NGCGraph(K=10, name="gncn_t1_ffm") |
| 122 | +ngc_model.proj_update_mag = -1.0 |
| 123 | +ngc_model.proj_weight_mag = -1.0 |
| 124 | +ngc_model.set_cycle(nodes=[z3,z2,z1,z0]) |
| 125 | +ngc_model.set_cycle(nodes=[mu2,mu1,mu0]) |
| 126 | +ngc_model.set_cycle(nodes=[e2,e1,e0]) |
| 127 | +ngc_model.apply_constraints() |
| 128 | + |
| 129 | +# build this NGC model's sampling graph |
| 130 | +z3_dim = ngc_model.getNode("z3").dim |
| 131 | +z2_dim = ngc_model.getNode("z2").dim |
| 132 | +z1_dim = ngc_model.getNode("z1").dim |
| 133 | +z0_dim = ngc_model.getNode("z0").dim |
| 134 | +# Set up complementary sampling graph to use in conjunction w/ NGC-graph |
| 135 | +s3 = FNode(name="s3", dim=z3_dim, act_fx="identity") |
| 136 | +s2 = FNode(name="s2", dim=z2_dim, act_fx="relu") |
| 137 | +s1 = FNode(name="s1", dim=z1_dim, act_fx="relu") |
| 138 | +s0 = FNode(name="s0", dim=z0_dim, act_fx="identity") |
| 139 | +s3_s2 = s3.wire_to(s2, src_var="phi(z)", dest_var="dz", point_to_path=z3_mu2) |
| 140 | +s2_s1 = s2.wire_to(s1, src_var="phi(z)", dest_var="dz", point_to_path=z2_mu1) |
| 141 | +s1_s0 = s1.wire_to(s0, src_var="phi(z)", dest_var="dz", point_to_path=z1_mu0) |
| 142 | +sampler = ProjectionGraph() |
| 143 | +sampler.set_cycle(nodes=[s3,s2,s1,s0]) |
| 144 | + |
| 145 | +# test projection graph |
| 146 | +print("----------------") |
| 147 | +print(" > Checking ancestral projection graph:") |
| 148 | +readouts = sampler.project( |
| 149 | + clamped_vars=[("s3","z",x)], |
| 150 | + readout_vars=[("s0","phi(z)")] |
| 151 | + ) |
| 152 | +x_sample = readouts[0][2] |
| 153 | + |
| 154 | +print(" => Test for: x = 1 = s2.z = s2.phi(z)") |
| 155 | +s3_z = sampler.extract("s3","z") |
| 156 | +s3_phi = sampler.extract("s3","phi(z)") |
| 157 | +np.testing.assert_array_equal(x.numpy(), s3_z.numpy()) |
| 158 | +np.testing.assert_array_equal(x.numpy(), s3_phi.numpy()) |
| 159 | +print(" PASS!") |
| 160 | + |
| 161 | +print(" => Test for: x = 1 = s1.z = s1.phi(z)") |
| 162 | +s1_z = sampler.extract("s1","z") |
| 163 | +s1_phi = sampler.extract("s1","phi(z)") |
| 164 | +np.testing.assert_array_equal(x.numpy(), s1_z.numpy()) |
| 165 | +np.testing.assert_array_equal(x.numpy(), s1_phi.numpy()) |
| 166 | +print(" PASS!") |
| 167 | + |
| 168 | +print(" => Test for: x = 1 = s0.z = s0.phi(z)") |
| 169 | +s0_z = sampler.extract("s0","z") |
| 170 | +s0_phi = sampler.extract("s0","phi(z)") |
| 171 | +np.testing.assert_array_equal(x.numpy(), s0_z.numpy()) |
| 172 | +np.testing.assert_array_equal(x.numpy(), s0_phi.numpy()) |
| 173 | +print(" PASS!") |
| 174 | + |
| 175 | +print(" => Test for: x = x_sample") |
| 176 | +print("Expected: ",x.numpy()) |
| 177 | +print(" Output: ",x_sample.numpy()) |
| 178 | +np.testing.assert_array_equal(x.numpy(), x_sample.numpy()) |
| 179 | +print(" PASS!") |
| 180 | + |
| 181 | + |
| 182 | +# test NGC simulation object |
| 183 | +print("----------------") |
| 184 | +print(" > Checking NGC simulation object:") |
| 185 | +readouts = ngc_model.settle( |
| 186 | + clamped_vars=[("z3","z",x),("z0","z",x)], |
| 187 | + readout_vars=[("mu0","phi(z)"),("mu1","phi(z)"),("mu2","phi(z)")] |
| 188 | + ) |
| 189 | +x_hat = readouts[0][2] |
| 190 | + |
| 191 | +print(" => Test for: 0 = e2.z = e2.phi(z)") |
| 192 | +target_value = np.zeros([1,z2_dim]) |
| 193 | +e2_z = ngc_model.extract("e2","z") |
| 194 | +e2_phi = ngc_model.extract("e2","phi(z)") |
| 195 | +np.testing.assert_array_equal(target_value, e2_z.numpy()) |
| 196 | +np.testing.assert_array_equal(target_value, e2_phi.numpy()) |
| 197 | +print(" PASS!") |
| 198 | + |
| 199 | +print(" => Test for: 0 = e1.z = e1.phi(z)") |
| 200 | +target_value = np.zeros([1,z1_dim]) |
| 201 | +e1_z = ngc_model.extract("e1","z") |
| 202 | +e1_phi = ngc_model.extract("e1","phi(z)") |
| 203 | +np.testing.assert_array_equal(target_value, e1_z.numpy()) |
| 204 | +np.testing.assert_array_equal(target_value, e1_phi.numpy()) |
| 205 | +print(" PASS!") |
| 206 | + |
| 207 | +print(" => Test for: 0 = e0.z = e0.phi(z)") |
| 208 | +target_value = np.zeros([1,z0_dim]) |
| 209 | +e0_z = ngc_model.extract("e0","z") |
| 210 | +e0_phi = ngc_model.extract("e0","phi(z)") |
| 211 | +np.testing.assert_array_equal(target_value, e0_z.numpy()) |
| 212 | +np.testing.assert_array_equal(target_value, e0_phi.numpy()) |
| 213 | +print(" PASS!") |
| 214 | + |
| 215 | +print(" => Test for: x = x_hat") |
| 216 | +print("Expected: ",x.numpy()) |
| 217 | +print(" Output: ",x_hat.numpy()) |
| 218 | +np.testing.assert_array_equal(x.numpy(), x_hat.numpy()) |
| 219 | +print(" PASS!") |
| 220 | + |
| 221 | +print(" => Test for update calculation: all dx should be = 0") |
| 222 | +delta = ngc_model.calc_updates() |
| 223 | +for i in range(len(delta)): |
| 224 | + target_dx = ngc_model.theta[i] * 0 |
| 225 | + dx = delta[i] |
| 226 | + np.testing.assert_array_equal(target_dx.numpy(), dx.numpy()) |
| 227 | +print(" PASS! (for all {} dx calculations)".format(len(delta))) |
| 228 | + |
| 229 | +print("#######################################################################") |
| 230 | + |
| 231 | +print("#######################################################################") |
| 232 | +print(" > Testing a proxy NGC graph w/ untied weights") |
| 233 | +# set up system nodes |
| 234 | +z3 = SNode(name="z3", dim=z3_dim, beta=beta, leak=0, act_fx="identity", |
| 235 | + integrate_kernel=integrate_cfg) |
| 236 | +mu2 = SNode(name="mu2", dim=z2_dim, act_fx="identity", zeta=0.0) |
| 237 | +e2 = ENode(name="e2", dim=z2_dim) |
| 238 | +z2 = SNode(name="z2", dim=z2_dim, beta=beta, leak=0, act_fx="relu", |
| 239 | + integrate_kernel=integrate_cfg) |
| 240 | +mu1 = SNode(name="mu1", dim=z1_dim, act_fx="identity", zeta=0.0) |
| 241 | +e1 = ENode(name="e1", dim=z1_dim) |
| 242 | +z1 = SNode(name="z1", dim=z1_dim, beta=beta, leak=0, act_fx="relu", |
| 243 | + integrate_kernel=integrate_cfg) |
| 244 | +mu0 = SNode(name="mu0", dim=z0_dim, act_fx="identity", zeta=0.0) |
| 245 | +e0 = ENode(name="e0", dim=z0_dim) |
| 246 | +z0 = SNode(name="z0", dim=z0_dim, beta=beta, leak=0.0, |
| 247 | + integrate_kernel=integrate_cfg) |
| 248 | + |
| 249 | +# create cable wiring scheme relating nodes to one another |
| 250 | +z3_mu2 = z3.wire_to(mu2, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg) |
| 251 | +mu2.wire_to(e2, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg) |
| 252 | +z2.wire_to(e2, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg) |
| 253 | +e2.wire_to(z3, src_var="phi(z)", dest_var="dz_bu", cable_kernel=dcable_cfg) |
| 254 | +e2.wire_to(z2, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg) |
| 255 | + |
| 256 | +z2_mu1 = z2.wire_to(mu1, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg) |
| 257 | +mu1.wire_to(e1, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg) |
| 258 | +z1.wire_to(e1, src_var="z", dest_var="pred_targ", cable_kernel=pos_scable_cfg) |
| 259 | +e1.wire_to(z2, src_var="phi(z)", dest_var="dz_bu", cable_kernel=dcable_cfg) |
| 260 | +e1.wire_to(z1, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg) |
| 261 | + |
| 262 | +z1_mu0 = z1.wire_to(mu0, src_var="phi(z)", dest_var="dz_td", cable_kernel=dcable_cfg) |
| 263 | +mu0.wire_to(e0, src_var="phi(z)", dest_var="pred_mu", cable_kernel=pos_scable_cfg) |
| 264 | +z0.wire_to(e0, src_var="phi(z)", dest_var="pred_targ", cable_kernel=pos_scable_cfg) |
| 265 | +e0.wire_to(z1, src_var="phi(z)", dest_var="dz_bu", cable_kernel=dcable_cfg) |
| 266 | +e0.wire_to(z0, src_var="phi(z)", dest_var="dz_td", cable_kernel=neg_scable_cfg) |
| 267 | + |
| 268 | +# set up update rules and make relevant edges aware of these |
| 269 | +z3_mu2.set_update_rule(preact=(z3,"phi(z)"), postact=(e2,"phi(z)")) |
| 270 | +z2_mu1.set_update_rule(preact=(z2,"phi(z)"), postact=(e1,"phi(z)")) |
| 271 | +z1_mu0.set_update_rule(preact=(z1,"phi(z)"), postact=(e0,"phi(z)")) |
| 272 | + |
| 273 | +# Set up graph - execution cycle/order |
| 274 | +ngc_model = NGCGraph(K=10, name="gncn_t1_ffm") |
| 275 | +ngc_model.proj_update_mag = -1.0 |
| 276 | +ngc_model.proj_weight_mag = -1.0 |
| 277 | +ngc_model.set_cycle(nodes=[z3,z2,z1,z0]) |
| 278 | +ngc_model.set_cycle(nodes=[mu2,mu1,mu0]) |
| 279 | +ngc_model.set_cycle(nodes=[e2,e1,e0]) |
| 280 | +ngc_model.apply_constraints() |
| 281 | + |
| 282 | +print("----------------") |
| 283 | +print(" > Checking NGC simulation object:") |
| 284 | +readouts = ngc_model.settle( |
| 285 | + clamped_vars=[("z3","z",x),("z0","z",x)], |
| 286 | + readout_vars=[("mu0","phi(z)"),("mu1","phi(z)"),("mu2","phi(z)")] |
| 287 | + ) |
| 288 | +x_hat = readouts[0][2] |
| 289 | + |
| 290 | +print(" => Test for: 0 = e2.z = e2.phi(z)") |
| 291 | +target_value = np.zeros([1,z2_dim]) |
| 292 | +e2_z = ngc_model.extract("e2","z") |
| 293 | +e2_phi = ngc_model.extract("e2","phi(z)") |
| 294 | +np.testing.assert_array_equal(target_value, e2_z.numpy()) |
| 295 | +np.testing.assert_array_equal(target_value, e2_phi.numpy()) |
| 296 | +print(" PASS!") |
| 297 | + |
| 298 | +print(" => Test for: 0 = e1.z = e1.phi(z)") |
| 299 | +target_value = np.zeros([1,z1_dim]) |
| 300 | +e1_z = ngc_model.extract("e1","z") |
| 301 | +e1_phi = ngc_model.extract("e1","phi(z)") |
| 302 | +np.testing.assert_array_equal(target_value, e1_z.numpy()) |
| 303 | +np.testing.assert_array_equal(target_value, e1_phi.numpy()) |
| 304 | +print(" PASS!") |
| 305 | + |
| 306 | +print(" => Test for: 0 = e0.z = e0.phi(z)") |
| 307 | +target_value = np.zeros([1,z0_dim]) |
| 308 | +e0_z = ngc_model.extract("e0","z") |
| 309 | +e0_phi = ngc_model.extract("e0","phi(z)") |
| 310 | +np.testing.assert_array_equal(target_value, e0_z.numpy()) |
| 311 | +np.testing.assert_array_equal(target_value, e0_phi.numpy()) |
| 312 | +print(" PASS!") |
| 313 | + |
| 314 | +print(" => Test for: x = x_hat") |
| 315 | +print("Expected: ",x.numpy()) |
| 316 | +print(" Output: ",x_hat.numpy()) |
| 317 | +np.testing.assert_array_equal(x.numpy(), x_hat.numpy()) |
| 318 | +print(" PASS!") |
| 319 | + |
| 320 | +print(" => Test for update calculation: all dx should be = 0") |
| 321 | +delta = ngc_model.calc_updates() |
| 322 | +for i in range(len(delta)): |
| 323 | + target_dx = ngc_model.theta[i] * 0 |
| 324 | + dx = delta[i] |
| 325 | + np.testing.assert_array_equal(target_dx.numpy(), dx.numpy()) |
| 326 | +print(" PASS! (for all {} dx calculations)".format(len(delta))) |
| 327 | + |
| 328 | +print("#######################################################################") |
0 commit comments